Source code for cyto_dl.image.transforms.clip

from typing import Sequence, Union

import torch
from monai.transforms import Transform
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
from omegaconf import ListConfig


class Clip(Transform):
    """Transform for clipping image intensities based on absolute or percentile values."""

    def __init__(self, low: float = 0.01, high: float = 99.99, percentile: bool = True):
        """
        Parameters
        ----------
        low: float
            lower bound for clipping
        high: float
            upper bound for clipping
        percentile: bool
            whether to use percentile or absolute values for clipping
        """
        super().__init__()
        self.low = low
        self.high = high
        self.percentile = percentile

    def __call__(self, img):
        low = self.low
        high = self.high
        if self.percentile:
            # interpret the bounds as percentiles and convert them to
            # intensity values computed from this image
            low = percentile(img, low)
            high = percentile(img, high)
        return clip(img, low, high)
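
# Illustrative usage sketch (not part of the original module); the tensor
# shape and percentile bounds below are arbitrary assumptions:
#
#   >>> img = torch.rand(1, 64, 64)
#   >>> clipped = Clip(low=1.0, high=99.0, percentile=True)(img)
#   >>> # intensities now lie within img's 1st-99th percentile range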


class Clipd(Transform):
    """Dictionary Transform for clipping image intensities based on absolute or percentile
    values."""

    def __init__(
        self,
        keys: Union[str, Sequence[str]],
        low: float = 0.01,
        high: float = 99.99,
        percentile: bool = True,
        allow_missing_keys: bool = False,
        per_channel: bool = True,
    ):
"""
Parameters
----------
keys: str
name of images to resize
low: float
lower bound for clipping
high: float
upper bound for clipping
percentile: bool
whether to use percentile or absolute values for clipping
allow_missing_keys: bool
whether to fail if provided keys are missing
"""
        super().__init__()
        self.keys = keys if isinstance(keys, (list, ListConfig)) else [keys]
        self.allow_missing_keys = allow_missing_keys
        self.clipper = Clip(low, high, percentile)
        self.per_channel = per_channel

    def __call__(self, img_dict):
        for key in self.keys:
            if key in img_dict:
                if self.per_channel:
                    # clip each channel independently, then restack along the channel axis
                    img_dict[key] = torch.stack([self.clipper(img) for img in img_dict[key]])
                else:
                    img_dict[key] = self.clipper(img_dict[key])
            elif not self.allow_missing_keys:
                raise KeyError(f"key `{key}` not available. Available keys are {img_dict.keys()}")
        return img_dict
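
# Runnable usage sketch (illustrative, not part of the original module); the
# key name, tensor shape, and bounds here are assumptions:
if __name__ == "__main__":
    img = torch.rand(2, 64, 64)  # (channel, H, W)

    # per-image percentile clipping
    clipped = Clip(low=1.0, high=99.0, percentile=True)(img)

    # dictionary transform, clipping each channel independently
    batch = Clipd(keys="raw", per_channel=True)({"raw": img})
    print(clipped.shape, batch["raw"].shape)  # both torch.Size([2, 64, 64])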