Source code for cyto_dl.models.im2im.utils.inferers

import torch
from monai.inferers import Merger, PatchInferer


class EmbeddingPatchMerger(Merger):
    """Merger that collects per-patch outputs together with their start coordinates.

    Every aggregated tensor is stored and concatenated along dim 0 on ``finalize``.
    Patch start coordinates are recorded per spatial axis under the keys
    ``start_z`` (3D only), ``start_y`` and ``start_x``.
    """

    def __init__(self, spatial_dims: int, **kwargs) -> None:
        """Set up empty storage for values and locations.

        Args:
            spatial_dims: number of spatial dimensions of the input, 2 or 3.
            **kwargs: forwarded unchanged to ``Merger.__init__``.

        Raises:
            ValueError: if ``spatial_dims`` is neither 2 nor 3.
        """
        super().__init__(**kwargs)
        self.values = []
        if spatial_dims not in (2, 3):
            raise ValueError(f"Expected spatial_dims to be 2 or 3, got {spatial_dims}")
        self.spatial_dims = spatial_dims
        # keep the trailing axis names: ["y", "x"] for 2D, ["z", "y", "x"] for 3D
        self.dim_names = ["z", "y", "x"][-spatial_dims:]
        self.locations = {f"start_{ax}": [] for ax in self.dim_names}

    def aggregate(self, values: torch.Tensor, locations) -> None:
        """Record one inference output and the start coordinates of its patch.

        Args:
            values: tensor holding the inference output for one patch;
                ``values.shape[0]`` is treated as the batch size.
            locations: a tuple/list giving the start coordinate of the patch in
                the original image, one entry per spatial axis, in the same
                order as ``self.dim_names``.

        Raises:
            ValueError: if ``len(locations)`` differs from ``self.spatial_dims``.
                Nothing is stored in that case.
        """
        b = values.shape[0]
        # Validate before touching any state so that a rejected call cannot leave
        # ``self.values`` holding an entry that has no matching location.
        if len(locations) != self.spatial_dims:
            raise ValueError(
                f"Expected {self.spatial_dims} spatial dimensions, got {len(locations)}"
            )
        self.values.append(values)
        for axis, loc in zip(self.dim_names, locations):
            # With a batch size above 1 the coordinate is repeated once per batch
            # element and stored as a single list entry; otherwise the bare
            # coordinate is stored.
            # NOTE(review): entries are therefore nested lists when b > 1, while
            # ``finalize`` concatenates values flat along dim 0 -- confirm that
            # consumers of ``self.locations`` expect this layout.
            loc = [loc] * b if b > 1 else loc
            self.locations[f"start_{axis}"].append(loc)

    def finalize(self):
        """Return everything aggregated so far.

        Returns:
            A tuple ``(values, locations)``: all aggregated tensors concatenated
            along dim 0, and the dict mapping ``start_<axis>`` to the list of
            recorded start coordinates for that axis.
        """
        return torch.cat(self.values, dim=0), self.locations
class EmbeddingPatchInferer(PatchInferer):
    """PatchInferer variant for models that embed a spatial patch into a single latent vector.

    PatchInferer is typically used for image-to-image applications, where input
    and output are both spatial but may differ in spatial size (e.g. in
    superresolution the output is larger than the input). The ratio between the
    two is used to calculate the location of an output patch in the input
    image. Forcing that ratio to 1.0 in every dimension associates each
    embedding with the input coordinates of the patch it came from.
    """

    def _initialize_mergers(self, *args, **kwargs):
        """Build the mergers via the parent class, then replace every ratio with all ones.

        All arguments are passed through to ``PatchInferer._initialize_mergers``.

        Returns:
            A tuple ``(mergers, ratios)`` where ``mergers`` is whatever the
            parent produced and ``ratios`` holds, for each ratio the parent
            produced, a tuple of ``1.0`` with one entry per element of
            ``self.splitter.patch_size``.
        """
        mergers, parent_ratios = super()._initialize_mergers(*args, **kwargs)
        unit_ratios = []
        for _ in parent_ratios:
            # one 1.0 per patch dimension, so output locations equal input locations
            unit_ratios.append((1.0,) * len(self.splitter.patch_size))
        return mergers, unit_ratios