cyto_dl.image.transforms.generate_jepa_masks module
- class cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator(spatial_dims: int, mask_size: int = 12, block_aspect_ratio: Tuple[float] = (0.5, 1.5), num_patches: Tuple[float] = (6, 24, 24), mask_ratio: float = 0.9)
Bases: RandomizableTransform
Transform for generating block-contiguous masks for JEPA training.
The mask is built in two phases: blocks are added at random until the mask_ratio is met or exceeded, and then blocks are removed from the exterior of the contiguous mask until the mask_ratio is met exactly.
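The two-phase procedure can be sketched with NumPy on the encoder's patch grid. This is a minimal illustration under stated assumptions, not the cyto_dl implementation. The following are all hypothetical: the helper name generate_block_mask, the boolean convention (True = masked), the rule that derives a block's shape from mask_size and a sampled aspect ratio, seeding each new block on an already-masked patch so the mask stays contiguous, and trimming one patch at a time from the mask's exterior.

```python
import numpy as np


def generate_block_mask(num_patches=(6, 24, 24), mask_size=12,
                        block_aspect_ratio=(0.5, 1.5), mask_ratio=0.9, seed=None):
    """Hypothetical sketch of a block-contiguous mask over a patch grid (True = masked)."""
    rng = np.random.default_rng(seed)
    shape = tuple(num_patches)
    ndim = len(shape)
    target = int(round(mask_ratio * np.prod(shape)))  # number of patches to mask
    mask = np.zeros(shape, dtype=bool)

    # Phase 1: add blocks until the target is met or exceeded.
    while mask.sum() < target:
        # Assumed shape rule: mask_size scaled by a sampled aspect ratio on every
        # axis except the last, clipped to the grid size along each axis.
        ratios = [rng.uniform(*block_aspect_ratio) for _ in range(ndim - 1)] + [1.0]
        block = [int(np.clip(round(mask_size * r), 1, s)) for r, s in zip(ratios, shape)]
        # The first block is placed at random; later blocks contain an already-masked
        # patch, so the union of blocks stays connected.
        if mask.any():
            masked = np.argwhere(mask)
            center = masked[rng.integers(len(masked))]
        else:
            center = [rng.integers(s) for s in shape]
        start = [int(np.clip(c - b // 2, 0, s - b)) for c, b, s in zip(center, block, shape)]
        mask[tuple(slice(st, st + b) for st, b in zip(start, block))] = True

    # Phase 2: remove exterior patches until the masked count equals the target.
    inner = tuple(slice(1, -1) for _ in range(ndim))
    while mask.sum() > target:
        padded = np.pad(mask, 1, constant_values=False)  # outside the grid counts as unmasked
        interior = mask.copy()
        for axis in range(ndim):
            for shift in (-1, 1):
                # A patch is interior only if every face neighbour is also masked.
                interior &= np.roll(padded, shift, axis=axis)[inner]
        exterior = np.argwhere(mask & ~interior)
        excess = int(mask.sum()) - target
        pick = rng.choice(len(exterior), size=min(excess, len(exterior)), replace=False)
        mask[tuple(exterior[pick].T)] = False
    return mask


mask = generate_block_mask(seed=0)
print(mask.shape, int(mask.sum()))
```

Trimming at patch granularity is what lets the count land exactly on the target. Removing exterior patches can, in rare cases, cut a thin bridge and split the mask. This sketch does not guard against that.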
- Parameters:
spatial_dims (int) – The number of spatial dimensions of the image (2 or 3)
mask_size (int, optional) – The size of the blocks used to generate the mask. Block dimensions are determined by the mask size and an aspect ratio sampled from the range block_aspect_ratio
block_aspect_ratio (Tuple[float], optional) – The lower and upper bounds of the aspect ratio of the mask blocks
num_patches (Tuple[int], optional) – The number of patches used by the encoder for each dimension of the image (ZYX for 3D, YX for 2D)
mask_ratio (float, optional) – The proportion of the image's patches to be masked
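The mask is defined on the num_patches grid, so mask_ratio corresponds to a whole number of masked patches. The check below uses the default values. It assumes that the ratio is applied to the total patch count and rounded to the nearest integer. The rounding rule is an assumption, not taken from the source.

```python
import math

num_patches = (6, 24, 24)  # default ZYX patch grid
mask_ratio = 0.9           # default masking proportion

total = math.prod(num_patches)      # 6 * 24 * 24 patches
masked = round(mask_ratio * total)  # 3110.4 rounded (assumed rule) to a whole patch count
unmasked = total - masked           # patches left visible

print(total, masked, unmasked)  # 3456 3110 346
```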