Source code for cyto_dl.image.transforms.project

from typing import Union

from monai.transforms import Transform
from monai.transforms.utils_pytorch_numpy_unification import max as _max
from monai.transforms.utils_pytorch_numpy_unification import mean, median
from monai.transforms.utils_pytorch_numpy_unification import min as _min
from monai.transforms.utils_pytorch_numpy_unification import mode, std
from omegaconf import ListConfig


class Projectd(Transform):  # codespell:ignore
    """Monai-style transform to apply projections (e.g., max, std) to an image."""

    def __init__(
        self,
        keys: Union[list, str],
        projection_dim: int = 1,
        projection_type: str = "max",
        allow_missing_keys: bool = False,
    ):
        """
        Parameters
        ----------
        keys: Union[list, str]
            keys to apply projection
        projection_dim: int=1
            index into C[Z]YX to compute projection across
        projection_type: str="max"
            Type of projection to apply. Options: "max", "min", "std", "median", "mode", "mean"
        allow_missing_keys: bool=False
            Whether to raise error if specified key is missing

        Raises
        ------
        ValueError
            If ``projection_type`` is not one of the supported options.
        """
        super().__init__()
        # A single key is wrapped in a list; list / ListConfig inputs are used as-is.
        self.keys = keys if isinstance(keys, (list, ListConfig)) else [keys]
        self.projection_dim = projection_dim
        self.allow_missing_keys = allow_missing_keys
        projection_fns = {
            "max": _max,
            "min": _min,
            "std": std,
            "median": median,
            "mode": mode,
            "mean": mean,
        }
        if projection_type not in projection_fns:
            raise ValueError(
                f"Unsupported projection_type: {projection_type}. Supported types: {projection_fns.keys()}"
            )
        # Kept for introspection/debugging; the callable below does the work.
        self.projection_type = projection_type
        self.projector = projection_fns[projection_type]

    def __call__(self, input_dict):
        """
        Parameters
        ----------
        input_dict: Dict[str, torch.Tensor]
            dict of C[Z]YX tensors

        Returns
        -------
        The same dict object, modified in place. Each key from ``self.keys``
        that is present is replaced by its projection along
        ``self.projection_dim``.

        Raises
        ------
        KeyError
            If a key is absent and ``allow_missing_keys`` is False.
        """
        for key in self.keys:
            if key in input_dict:
                input_dict[key] = self._project(input_dict[key])
            elif not self.allow_missing_keys:
                raise KeyError(
                    f"key `{key}` not available. Available keys are {input_dict.keys()}"
                )
        return input_dict

    def _project(self, img):
        """Apply the configured projection to one image and return only the reduced values."""
        projected = self.projector(img, dim=self.projection_dim)
        # Some torch reductions called with ``dim`` (e.g. torch.median) return a
        # (values, indices) named tuple. NOTE(review): depending on the MONAI
        # version, its wrapper may pass that tuple through unchanged, so keep
        # only the values here. Plain ndarray / Tensor results are not tuples
        # and are returned as-is.
        if isinstance(projected, tuple):
            projected = projected[0]
        return projected