# Source code for cyto_dl.image.transforms.pad

from typing import Dict, Optional, Union

import numpy as np
import torch
from monai.transforms import Transform
from monai.transforms.croppad.functional import pad_nd
from omegaconf import ListConfig, OmegaConf


class PadZd(Transform):
    """Pad the start and/or end of a crop along Z by repeating its edge slice.

    An end of the stack is padded only when the selected segmentation channel
    is entirely zero in the corresponding (first or last) slice. That decision
    is made once from the segmentation and then applied to every key in
    ``pad_keys``, each with its own amount looked up in ``pad_amount``.
    """

    def __init__(
        self,
        image_key: str,
        segmentation_key: str,
        pad_amount: Dict[str, int],
        pad_keys: Optional[Union[str, list, ListConfig]] = None,
        segmentation_ch: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        image_key: str
            name of the image to pad. Whether this entry is a ``torch.Tensor``
            also selects the padding mode used for every key
        segmentation_key: str
            name of the segmentation. Used for checking whether the first or
            last slice is empty and can therefore be padded
        pad_amount: Dict[str, int]
            number of slices to pad, looked up per padded key. May be an
            OmegaConf config or a plain mapping; it needs an entry for every
            key that gets padded
        pad_keys: Optional[Union[str, list, ListConfig]]
            additional key(s) to pad alongside ``image_key`` and
            ``segmentation_key``. A single string is treated as one key
        segmentation_ch: Optional[int]
            channel of the segmentation to check. Required when the
            segmentation has more than one channel
        """
        super().__init__()
        self.image_key = image_key
        self.segmentation_key = segmentation_key

        if pad_keys is None:
            extra_keys = []
        elif isinstance(pad_keys, str):
            # A bare string is a single key, not an iterable of characters.
            extra_keys = [pad_keys]
        else:
            extra_keys = list(pad_keys)
        # dict.fromkeys drops duplicates while keeping order, so a key that is
        # listed twice is not padded twice.
        self.pad_keys = list(dict.fromkeys([image_key, segmentation_key, *extra_keys]))

        # OmegaConf.to_container only accepts OmegaConf configs; plain mappings
        # are copied as-is.
        if OmegaConf.is_config(pad_amount):
            self.pad_amount = OmegaConf.to_container(pad_amount)
        else:
            self.pad_amount = dict(pad_amount)
        self.segmentation_ch = segmentation_ch

    def __call__(self, img_dict):
        """Pad every configured key of ``img_dict`` in place and return it.

        Raises
        ------
        ValueError
            if the segmentation has more than one channel and
            ``segmentation_ch`` was not given
        """
        image = img_dict[self.image_key]
        segmentation = img_dict[self.segmentation_key]

        if segmentation.shape[0] > 1 and self.segmentation_ch is None:
            raise ValueError(
                "segmentation_ch must be specified if segmentation has more than one channel"
            )
        elif segmentation.shape[0] == 1:
            ch_seg = segmentation[0]
        else:
            ch_seg = segmentation[self.segmentation_ch]

        # torch and numpy use different names for edge-replication padding.
        pad_mode = "replicate" if isinstance(image, torch.Tensor) else "edge"

        # Decide once, before anything is padded, which ends are free of
        # segmentation; the same decision is applied to every key.
        pad_before = bool((ch_seg[0] == 0).all())
        pad_after = bool((ch_seg[-1] == 0).all())

        for key in self.pad_keys:
            amount = self.pad_amount[key]
            z_pad = (amount if pad_before else 0, amount if pad_after else 0)
            # (before, after) per axis; assumes arrays are laid out CZYX, so
            # only axis 1 (Z) is padded.
            pad = [(0, 0), z_pad, (0, 0), (0, 0)]
            img_dict[key] = pad_nd(img_dict[key], pad, mode=pad_mode)
        return img_dict