# Source code for cyto_dl.image.transforms.generate_jepa_masks

from typing import Tuple

import numpy as np
from einops import rearrange
from monai.transforms import RandomizableTransform
from skimage.segmentation import find_boundaries

from cyto_dl.nn.vits.utils import match_tuple_dimensions


class JEPAMaskGenerator(RandomizableTransform):
    """Transform for generating block-contiguous masks for JEPA training.

    Rectangular blocks are added at random positions until the number of masked
    patches meets or exceeds the target derived from ``mask_ratio``. Patches are then
    removed at random from the boundary of the mask until the target is met exactly.
    """

    def __init__(
        self,
        spatial_dims: int,
        mask_size: int = 12,
        block_aspect_ratio: Tuple[float, float] = (0.5, 1.5),
        num_patches: Tuple[int, ...] = (6, 24, 24),
        mask_ratio: float = 0.9,
    ):
        """
        Parameters
        ----------
        spatial_dims : int
            The number of spatial dimensions of the image (2 or 3)
        mask_size : int, optional
            The size of the blocks used to generate mask. Block dimensions are determined by
            the mask size and an aspect ratio sampled from the range `block_aspect_ratio`
        block_aspect_ratio : Tuple[float, float], optional
            The low and high values for aspect ratio of the mask blocks
        num_patches : Tuple[int, ...], optional
            The number of patches used by the encoder for each dimension of the image
            (ZYX for 3D, YX for 2D)
        mask_ratio : float, optional
            The proportion of the image to be masked

        Raises
        ------
        AssertionError
            If `mask_ratio` is not strictly between 0 and 1, or if the widest possible
            block does not fit in the smallest of the last two patch dimensions.
        ValueError
            If `spatial_dims` is neither 2 nor 3.
        """
        # initialise the base-class state (the original skipped this call)
        super().__init__()

        assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"
        num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0]
        assert mask_size * max(block_aspect_ratio) < min(
            num_patches[-2:]
        ), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches"

        self.mask_size = mask_size
        self.block_aspect_ratio = block_aspect_ratio
        self.num_patches = num_patches

        # convert mask_ratio to number of pixels to be masked
        self.target_pix = int(mask_ratio * np.prod(num_patches))

        # empty template that every call copies and fills
        self.mask = np.zeros(num_patches)
        self.edge_mask = np.ones(num_patches)
        self.spatial_dims = spatial_dims

        # create a mask that identifies pixels on the edge of the image
        if self.spatial_dims == 3:
            self.edge_mask[1:-1, 1:-1, 1:-1] = 0
        elif self.spatial_dims == 2:
            self.edge_mask[1:-1, 1:-1] = 0
        else:
            raise ValueError("num_patches must be 2 or 3 dimensions")

    def remove_excess_pixels(self, mask):
        """Remove pixels along the boundary of the mask until the target number of pixels is
        reached.

        Parameters
        ----------
        mask : np.ndarray
            Mask on the patch grid with nonzero entries marking masked patches. It is
            modified in place.

        Returns
        -------
        np.ndarray
            The trimmed mask, flattened in C order and cast to bool.
        """
        # number of pixels to remove from the contiguous mask
        excess = int(mask.sum() - self.target_pix)

        # A single boundary layer can hold fewer pixels than must be removed, so keep
        # peeling layers until the target is met. The first pass draws from the random
        # state exactly as a one-shot removal would.
        while True:
            boundary = find_boundaries(mask, mode="inner")
            # include image edge as boundary, not just 1:0 transitions
            on_image_edge = np.logical_and(mask, self.edge_mask)
            boundary_coords = np.argwhere(np.logical_or(boundary, on_image_edge))

            n_boundary = boundary_coords.shape[0]
            if n_boundary == 0:
                # nothing left that could be removed
                break

            n_remove = min(excess, n_boundary)
            chosen = self.R.choice(n_boundary, n_remove, replace=False)
            # one index array per axis, so the same line serves 2D and 3D masks
            mask[tuple(boundary_coords[chosen].T)] = 0

            excess -= n_remove
            if excess <= 0:
                break

        return mask.reshape(-1).astype(bool)

    def __call__(self, img_dict):
        """Generate a target mask and its complementary context mask.

        Blocks are added until ``target_pix`` is reached, then excess boundary pixels are
        randomly removed. The context mask is the inverse of the target mask.

        Parameters
        ----------
        img_dict : dict
            Dictionary to which the masks are added. Existing entries are not read.

        Returns
        -------
        dict
            The same dictionary with two added keys: ``"target_mask"`` holding the 1D
            indices of masked patches in the flattened patch grid, and ``"context_mask"``
            holding the 1D indices of the remaining patches.
        """
        grid_height = int(self.num_patches[-2])
        grid_width = int(self.num_patches[-1])

        mask = self.mask.copy()
        while mask.sum() < self.target_pix:
            # randomly select block shape
            aspect_ratio = self.R.uniform(*self.block_aspect_ratio)
            # Clamp each side to [1, grid size]: truncation to 0 would add an empty block
            # (and can stall this loop), while a side larger than the grid would make the
            # position sampling below raise.
            width = min(max(int(self.mask_size * aspect_ratio), 1), grid_width)
            height = min(max(int(self.mask_size / aspect_ratio), 1), grid_height)

            # randomly select block position
            x = self.R.randint(0, grid_width - width + 1)
            y = self.R.randint(0, grid_height - height + 1)

            # add block to mask; in 3D the block spans the whole leading axis
            mask[..., y : y + height, x : x + width] = 1

        mask = self.remove_excess_pixels(mask)

        # flatnonzero always returns a 1D index array, even when only one index is selected
        img_dict["context_mask"] = np.flatnonzero(~mask)
        img_dict["target_mask"] = np.flatnonzero(mask)
        return img_dict