Source code for cyto_dl.image.transforms.project

from typing import Union

import torch
from monai.transforms import Transform
from omegaconf import ListConfig


class MaxProjectd(Transform):
    """Monai-style dictionary transform that max-projects selected entries.

    Every selected entry of the input dictionary is replaced by the result of
    ``torch.max`` taken along ``projection_dim``. Only the maximum values are
    kept; the accompanying indices are discarded.
    """

    def __init__(
        self,
        keys: Union[list, str],
        projection_dim: int = 1,
        allow_missing_keys: bool = False,
    ):
        """
        Parameters
        ----------
        keys: Union[list, str]
            Key, or list of keys, whose entries are max-projected. Anything
            that is not a ``list``/``ListConfig`` is treated as a single key.
        projection_dim: int=1
            Dimension handed to ``torch.max`` as ``dim``.
            NOTE(review): earlier documentation calls this an index into
            NCZYX, while ``__call__`` documents CZYX inputs - confirm which
            layout callers actually pass.
        allow_missing_keys: bool=False
            If True, keys absent from the input dictionary are skipped. If
            False, a missing key raises ``KeyError``.
        """
        super().__init__()
        # Normalise a lone key into a one-element list so that __call__ can
        # always iterate; list-like containers are stored as given.
        if not isinstance(keys, (list, ListConfig)):
            keys = [keys]
        self.keys = keys
        self.projection_dim = projection_dim
        self.allow_missing_keys = allow_missing_keys

    def __call__(self, input_dict):
        """Max-project the configured entries of ``input_dict`` in place.

        Parameters
        ----------
        input_dict: Dict[str, torch.Tensor]
            dict of CZYX tensors/metadata

        Returns
        -------
        The same ``input_dict`` object, with each present key overwritten by
        its projection.

        Raises
        ------
        KeyError
            If a configured key is missing and ``allow_missing_keys`` is False.
        """
        for key in self.keys:
            if key not in input_dict.keys():
                if self.allow_missing_keys:
                    continue
                raise KeyError(
                    f"key `{key}` not available. Available keys are {input_dict.keys()}"
                )
            # torch.max with ``dim`` yields (values, indices); keep the values.
            input_dict[key] = torch.max(input_dict[key], dim=self.projection_dim)[0]
        return input_dict