Source code for cyto_dl.models.im2im.utils.postprocessing.act_thresh_label

from typing import Callable

import numpy as np
import torch
from hydra.utils import get_class
from numpy.typing import DTypeLike
from skimage.exposure import rescale_intensity
from skimage.measure import label


class ActThreshLabel:
    """General-purpose postprocessing transform for applying any of an activation, threshold,
    labeling, and rescaling."""

    def __init__(
        self,
        activation: Callable = torch.nn.Identity(),
        threshold: float = None,
        label: bool = False,
        dtype: DTypeLike = np.uint8,
        ch: int = -1,
        rescale_dtype: DTypeLike = None,
    ):
        """
        Parameters
        ----------
        activation: Callable=torch.nn.Identity()
            Activation to apply to the image.
        threshold: float=None
            Threshold value; pixels strictly greater than it become True.
            Defaults to None, where no threshold is applied.
        label: bool=False
            Whether to run connected-component labeling on the image
            (after thresholding, if a threshold is set).
        dtype: DTypeLike=np.uint8
            Data type of the output image. May be a type, a ``np.dtype``
            instance, or a dotted import path string (resolved with
            ``hydra.utils.get_class``). Overridden by ``rescale_dtype`` when
            that is given.
        ch: int=-1
            Index (along the first axis) of the channel to postprocess.
            Any negative value, including the default -1, keeps all channels.
        rescale_dtype: DTypeLike=None
            dtype whose intensity range the output is rescaled to. Defaults
            to None, where no rescaling is applied. When set, it also
            becomes the output dtype.
        """
        self.activation = activation
        self.threshold = threshold
        self.label = label
        self.dtype = self._get_dtype(dtype)
        self.ch = ch
        self.rescale_dtype = self._get_dtype(rescale_dtype)
        # A rescaled image only makes sense in the dtype it was rescaled to,
        # so the rescale dtype takes precedence over `dtype`.
        if self.rescale_dtype is not None:
            self.dtype = self.rescale_dtype

    def _get_dtype(self, dtype: DTypeLike) -> DTypeLike:
        """Normalize a dtype specification to a type object (or None).

        Parameters
        ----------
        dtype: DTypeLike
            None, a dotted import path string, a type, or a ``np.dtype``
            instance.

        Returns
        -------
        The resolved type, or None if ``dtype`` is None.

        Raises
        ------
        ValueError
            If ``dtype`` is none of the supported kinds.
        """
        if dtype is None:
            return None
        if isinstance(dtype, str):
            # Strings come from config files, e.g. "numpy.uint8".
            return get_class(dtype)
        if isinstance(dtype, type):
            return dtype
        if isinstance(dtype, np.dtype):
            # np.dtype instances are valid DTypeLike values; use their scalar
            # type so downstream code sees the same kind of object as for
            # `np.uint8`-style arguments.
            return dtype.type
        raise ValueError(f"Expected dtype to be DtypeLike, string, or None, got {type(dtype)}")

    def __call__(self, img: torch.Tensor) -> np.ndarray:
        """Postprocess ``img`` and return it as a numpy array of ``self.dtype``.

        Steps, in order: optional channel selection, activation (on a detached
        float CPU copy), optional thresholding, optional labeling, optional
        intensity rescaling, and a final cast to the output dtype.
        """
        # `>= 0` (not `> 0`) so that channel 0 can be selected; negative
        # values are the "all channels" sentinel.
        if self.ch >= 0:
            img = img[self.ch]

        img = self.activation(img.detach().cpu().float()).numpy()

        if self.threshold is not None:
            img = img > self.threshold

        if self.label:
            # `label` here is the module-level skimage.measure.label function.
            img = label(img)

        if self.rescale_dtype is not None:
            # The cast to the rescale dtype happens in the return below,
            # since self.dtype == self.rescale_dtype whenever rescaling is on.
            img = rescale_intensity(img, out_range=self.rescale_dtype)

        return img.astype(self.dtype)