cyto_dl.image.transforms.generate_jepa_masks module

class cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator(spatial_dims: int, mask_size: int = 12, block_aspect_ratio: Tuple[float] = (0.5, 1.5), num_patches: Tuple[float] = (6, 24, 24), mask_ratio: float = 0.9)

Bases: RandomizableTransform

Transform for generating block-contiguous masks for JEPA training.

The generator adds randomly placed mask blocks until the masked fraction meets or exceeds mask_ratio, then removes pixels from the outer boundary of the contiguous mask until the masked fraction matches mask_ratio exactly.
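The first phase described above (adding blocks until the ratio is met or exceeded) can be sketched in plain Python for a 2D patch grid. This is an illustrative sketch, not the cyto_dl implementation: the function name `add_blocks_until_ratio`, the fixed block shape, the uniform block placement, and the floor rounding of the target count are all assumptions, and the sketch does not enforce that the blocks touch each other.

```python
import random


def add_blocks_until_ratio(grid_shape, block_shape, mask_ratio, rng):
    """Place random rectangular blocks until at least `mask_ratio` of the grid is masked.

    Hypothetical helper for illustration only.
    """
    rows, cols = grid_shape
    block_h, block_w = block_shape
    target = int(rows * cols * mask_ratio)  # assumed rounding: floor
    mask = [[False] * cols for _ in range(rows)]
    masked = 0
    while masked < target:
        # Pick the top-left corner so the block fits inside the grid.
        top = rng.randint(0, rows - block_h)
        left = rng.randint(0, cols - block_w)
        for r in range(top, top + block_h):
            for c in range(left, left + block_w):
                if not mask[r][c]:
                    mask[r][c] = True
                    masked += 1
    return mask, target


rng = random.Random(0)
mask, target = add_blocks_until_ratio((24, 24), (6, 8), 0.9, rng)
masked = sum(cell for row in mask for cell in row)
print(masked, ">=", target)
```

Because whole blocks are added at a time, the loop usually overshoots the target, which is why a second trimming phase is needed to reach the ratio exactly.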

Parameters:
  • spatial_dims (int) – The number of spatial dimensions of the image (2 or 3)

  • mask_size (int, optional) – The size of the blocks used to generate the mask. Block dimensions are determined by mask_size and an aspect ratio sampled from the range block_aspect_ratio

  • block_aspect_ratio (Tuple[float], optional) – The low and high values for aspect ratio of the mask blocks

  • num_patches (Tuple[int], optional) – The number of patches used by the encoder for each dimension of the image (ZYX for 3D, YX for 2D)

  • mask_ratio (float, optional) – The proportion of the image to be masked
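One way to turn mask_size and a sampled aspect ratio into block dimensions is sketched below. The documentation does not state the exact formula, so the area-preserving rule used here (height scaled by the square root of the aspect ratio, width divided by it) is an assumption, as are the helper name `sample_block_shape` and the uniform sampling.

```python
import math
import random


def sample_block_shape(mask_size, block_aspect_ratio, rng):
    """Sample a (height, width) pair for one mask block.

    Hypothetical helper: assumes the block area stays close to mask_size**2
    while height / width follows the sampled aspect ratio.
    """
    low, high = block_aspect_ratio
    aspect = rng.uniform(low, high)
    height = max(1, round(mask_size * math.sqrt(aspect)))
    width = max(1, round(mask_size / math.sqrt(aspect)))
    return height, width


rng = random.Random(0)
shapes = [sample_block_shape(12, (0.5, 1.5), rng) for _ in range(5)]
print(shapes)
```

With the default values mask_size=12 and block_aspect_ratio=(0.5, 1.5), this rule yields heights between 8 and 15 and widths between 10 and 17.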

remove_excess_pixels(mask)

Remove pixels along the boundary of the mask until the target number of pixels is reached.
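A minimal sketch of this kind of boundary trimming, in plain Python on a 2D boolean grid, is shown below. It is not the library's implementation: the function name `trim_boundary`, the explicit `target` argument, the 4-neighbour definition of a boundary pixel, and the random choice among boundary pixels are assumptions made for illustration.

```python
import random


def trim_boundary(mask, target, rng):
    """Unmask randomly chosen boundary cells until exactly `target` cells remain masked.

    Hypothetical helper; assumes 0 <= target <= number of masked cells.
    """
    rows, cols = len(mask), len(mask[0])
    mask = [row[:] for row in mask]  # work on a copy, leave the input untouched
    masked = sum(cell for row in mask for cell in row)

    def on_boundary(r, c):
        # A masked cell is on the boundary if any 4-neighbour is unmasked
        # or lies outside the grid.
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nr, nc = r + dr, c + dc
            if not (0 <= nr < rows and 0 <= nc < cols) or not mask[nr][nc]:
                return True
        return False

    while masked > target:
        boundary = [
            (r, c)
            for r in range(rows)
            for c in range(cols)
            if mask[r][c] and on_boundary(r, c)
        ]
        r, c = rng.choice(boundary)
        mask[r][c] = False
        masked -= 1
    return mask


# A 3x3 block (9 masked cells) inside a 5x5 grid, trimmed down to 6 cells.
original = [[1 <= r <= 3 and 1 <= c <= 3 for c in range(5)] for r in range(5)]
trimmed = trim_boundary(original, 6, random.Random(0))
print(sum(cell for row in trimmed for cell in row))
```

Removing only boundary cells keeps the interior of the mask intact, so the trimmed mask stays a subset of the original one.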