Source code for cyto_dl.nn.vits.mae

# modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union

import numpy as np
import torch

from cyto_dl.nn.vits.decoder import CrossMAE_Decoder, MAE_Decoder
from cyto_dl.nn.vits.encoder import HieraEncoder, MAE_Encoder
from cyto_dl.nn.vits.utils import match_tuple_dimensions


class MAE_Base(torch.nn.Module, ABC):
    def __init__(
        self, spatial_dims, num_patches, patch_size, mask_ratio, features_only, context_pixels
    ):
        super().__init__()
        num_patches, patch_size, context_pixels = match_tuple_dimensions(
            spatial_dims, [num_patches, patch_size, context_pixels]
        )

        self.spatial_dims = spatial_dims
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.features_only = features_only
        self.context_pixels = context_pixels

    # encoder and decoder must be defined in subclasses
    @property
    @abstractmethod
    def encoder(self):
        pass

    @property
    @abstractmethod
    def decoder(self):
        pass

    def init_encoder(self):
        raise NotImplementedError

    def init_decoder(self):
        raise NotImplementedError

    def forward(self, img):
        features, mask, forward_indexes, backward_indexes = self.encoder(img, self.mask_ratio)
        if self.features_only:
            return features
        predicted_img = self.decoder(features, forward_indexes, backward_indexes)
        return predicted_img, mask
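# Note: match_tuple_dimensions (imported from cyto_dl.nn.vits.utils) lets
# callers pass either a scalar or a per-dimension sequence, as the docstrings
# below describe. A minimal, hypothetical stand-in consistent with how it is
# called in this module might look like the sketch below; the real utility
# may differ in details.

def match_tuple_dimensions_sketch(spatial_dims, tuples):
    # Illustrative only: broadcast scalars to one value per spatial dimension;
    # pass sequences through unchanged after checking their length.
    matched = []
    for t in tuples:
        if isinstance(t, (list, tuple)):
            assert len(t) == spatial_dims, "length must equal spatial_dims"
            matched.append(list(t))
        else:
            matched.append([t] * spatial_dims)
    return matched

# e.g. match_tuple_dimensions_sketch(3, [16, 4, 0])
#   -> [[16, 16, 16], [4, 4, 4], [0, 0, 0]]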
class MAE(MAE_Base):
    def __init__(
        self,
        spatial_dims: int = 3,
        num_patches: Optional[Union[int, List[int]]] = 16,
        patch_size: Optional[Union[int, List[int]]] = 4,
        emb_dim: Optional[int] = 768,
        encoder_layer: Optional[int] = 12,
        encoder_head: Optional[int] = 8,
        decoder_layer: Optional[int] = 4,
        decoder_head: Optional[int] = 8,
        decoder_dim: Optional[int] = 192,
        mask_ratio: Optional[float] = 0.75,
        use_crossmae: Optional[bool] = False,
        context_pixels: Optional[Union[int, List[int]]] = 0,
        input_channels: Optional[int] = 1,
        features_only: Optional[bool] = False,
        learnable_pos_embedding: Optional[bool] = True,
    ) -> None:
        """
        Parameters
        ----------
        spatial_dims: int
            Number of spatial dimensions
        num_patches: int, List[int]
            Number of patches in each dimension (ZYX order). If int, the same
            number of patches is used in each dimension.
        patch_size: int, List[int]
            Size of each patch (ZYX order). If int, the same patch size is
            used in each dimension.
        emb_dim: int
            Dimension of encoder embedding
        encoder_layer: int
            Number of encoder transformer layers
        encoder_head: int
            Number of encoder heads
        decoder_layer: int
            Number of decoder transformer layers
        decoder_head: int
            Number of decoder heads
        decoder_dim: int
            Dimension of decoder embedding
        mask_ratio: float
            Ratio of patches to mask out
        use_crossmae: bool
            Use CrossMAE-style decoder
        context_pixels: int, List[int]
            Number of extra pixels around each patch to include in the
            convolutional embedding to the encoder dimension.
        input_channels: int
            Number of input channels
        features_only: bool
            Only use encoder to extract features
        learnable_pos_embedding: bool
            If True, learnable positional embeddings are used. If False, fixed
            sin/cos positional embeddings are used. Empirically, fixed
            positional embeddings work better for brightfield images.
        """
        super().__init__(
            spatial_dims=spatial_dims,
            num_patches=num_patches,
            patch_size=patch_size,
            mask_ratio=mask_ratio,
            features_only=features_only,
            context_pixels=context_pixels,
        )

        self._encoder = MAE_Encoder(
            self.num_patches,
            spatial_dims,
            self.patch_size,
            emb_dim,
            encoder_layer,
            encoder_head,
            self.context_pixels,
            input_channels,
            n_intermediate_weights=-1 if not use_crossmae else decoder_layer,
        )

        decoder_class = MAE_Decoder
        if use_crossmae:
            decoder_class = CrossMAE_Decoder
        self._decoder = decoder_class(
            num_patches=self.num_patches,
            spatial_dims=spatial_dims,
            patch_size=self.patch_size,
            enc_dim=emb_dim,
            emb_dim=decoder_dim,
            num_layer=decoder_layer,
            num_head=decoder_head,
            learnable_pos_embedding=learnable_pos_embedding,
        )

    @property
    def encoder(self):
        return self._encoder

    @property
    def decoder(self):
        return self._decoder
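# Usage sketch (illustrative, not part of the module), using the names defined
# above. With the defaults, each spatial extent of the input must equal
# num_patches * patch_size = 16 * 4 = 64; the shape comments are assumptions
# derived from the constructor arguments, not guarantees from the source.

def _mae_usage_example():
    model = MAE(
        spatial_dims=3,
        num_patches=16,   # broadcast to (16, 16, 16)
        patch_size=4,     # broadcast to (4, 4, 4) -> 64 pixels per dimension
        mask_ratio=0.75,  # 75% of patches are masked
    )
    img = torch.randn(1, 1, 64, 64, 64)  # (batch, channel, Z, Y, X)
    predicted_img, mask = model(img)     # reconstruction plus patch mask

    # with features_only=True, forward returns only the encoder features
    feature_extractor = MAE(spatial_dims=3, features_only=True)
    features = feature_extractor(img)
    return predicted_img, mask, features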
class HieraMAE(MAE_Base):
    def __init__(
        self,
        architecture: List[Dict],
        spatial_dims: int = 3,
        num_patches: Optional[Union[int, List[int]]] = 16,
        num_mask_units: Optional[Union[int, List[int]]] = 8,
        patch_size: Optional[Union[int, List[int]]] = 4,
        emb_dim: Optional[int] = 64,
        decoder_layer: Optional[int] = 4,
        decoder_head: Optional[int] = 8,
        decoder_dim: Optional[int] = 192,
        mask_ratio: Optional[float] = 0.75,
        use_crossmae: Optional[bool] = False,
        context_pixels: Optional[Union[int, List[int]]] = 0,
        input_channels: Optional[int] = 1,
        features_only: Optional[bool] = False,
    ) -> None:
        """
        Parameters
        ----------
        architecture: List[Dict]
            List of dictionaries specifying the architecture of the
            transformer. Each dictionary should have the following keys:
            - repeat: int
                Number of times to repeat the block
            - num_heads: int
                Number of heads in the multihead attention
            - q_stride: List[int]
                Stride for the query in each spatial dimension
            - self_attention: bool
                Whether to use self attention or mask unit attention
        spatial_dims: int
            Number of spatial dimensions
        num_patches: int, List[int]
            Number of patches in each dimension ((Z)YX order). If int, the
            same number of patches is used in each dimension.
        num_mask_units: int, List[int]
            Number of mask units in each dimension ((Z)YX order). If int, the
            same number of mask units is used in each dimension.
        patch_size: int, List[int]
            Size of each patch ((Z)YX order). If int, the same patch size is
            used in each dimension.
        emb_dim: int
            Dimension of embedding
        decoder_layer: int
            Number of layers in the decoder
        decoder_head: int
            Number of heads in the decoder
        decoder_dim: int
            Dimension of the decoder
        mask_ratio: float
            Fraction of mask units to remove
        use_crossmae: bool
            Use CrossMAE-style decoder instead of MAE decoder
        context_pixels: int, List[int]
            Number of extra pixels around each patch to include in the
            convolutional embedding to the encoder dimension.
        input_channels: int
            Number of input channels
        features_only: bool
            Only use encoder to extract features
        """
        super().__init__(
            spatial_dims=spatial_dims,
            num_patches=num_patches,
            patch_size=patch_size,
            mask_ratio=mask_ratio,
            features_only=features_only,
            context_pixels=context_pixels,
        )
        num_mask_units = match_tuple_dimensions(self.spatial_dims, [num_mask_units])[0]

        self._encoder = HieraEncoder(
            num_patches=self.num_patches,
            num_mask_units=num_mask_units,
            architecture=architecture,
            emb_dim=emb_dim,
            spatial_dims=self.spatial_dims,
            patch_size=self.patch_size,
            context_pixels=self.context_pixels,
            input_channels=input_channels,
        )
        # "patches" passed to the decoder are actually mask units, so
        # num_patches is num_mask_units and patch_size is the mask unit size
        mask_unit_size = (np.array(self.num_patches) * np.array(self.patch_size)) / np.array(
            num_mask_units
        )

        decoder_class = MAE_Decoder
        if use_crossmae:
            decoder_class = CrossMAE_Decoder
        self._decoder = decoder_class(
            num_patches=num_mask_units,
            spatial_dims=spatial_dims,
            patch_size=mask_unit_size.astype(int),
            enc_dim=self.encoder.final_dim,
            emb_dim=decoder_dim,
            num_layer=decoder_layer,
            num_head=decoder_head,
            has_cls_token=False,
        )

    @property
    def encoder(self):
        return self._encoder

    @property
    def decoder(self):
        return self._decoder
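# Hypothetical configuration sketch for HieraMAE (illustrative values, not a
# recommended setting): with 16 patches of size 4 per dimension and 8 mask
# units per dimension, each mask unit covers 16 * 4 / 8 = 8 pixels per
# dimension, which is the "patch" size the decoder sees. Valid q_stride
# choices depend on HieraEncoder's requirements.

def _hiera_mae_usage_example():
    # each dict describes one stage; keys follow the docstring above
    architecture = [
        {"repeat": 2, "num_heads": 1, "q_stride": [2, 2, 2], "self_attention": False},
        {"repeat": 2, "num_heads": 2, "q_stride": [1, 1, 1], "self_attention": True},
    ]
    model = HieraMAE(
        architecture=architecture,
        spatial_dims=3,
        num_patches=16,    # 16 patches per dimension
        num_mask_units=8,  # mask unit size = 16 * 4 / 8 = 8 pixels per dimension
        patch_size=4,
        mask_ratio=0.75,   # fraction of mask units removed
    )
    img = torch.randn(1, 1, 64, 64, 64)  # (batch, channel, Z, Y, X): 16 * 4 = 64
    predicted_img, mask = model(img)
    return predicted_img, mask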