cyto_dl.nn.vits.mae module

class cyto_dl.nn.vits.mae.HieraMAE(architecture: List[Dict], spatial_dims: int = 3, num_patches: int | List[int] | None = 16, num_mask_units: int | List[int] | None = 8, patch_size: int | List[int] | None = 4, emb_dim: int | None = 64, decoder_layer: int | None = 4, decoder_head: int | None = 8, decoder_dim: int | None = 192, mask_ratio: int | None = 0.75, use_crossmae: bool | None = False, context_pixels: List[int] | None = 0, input_channels: int | None = 1, features_only: bool | None = False)

Bases: MAE_Base

Parameters:
  • architecture (List[Dict]) – List of dictionaries specifying the architecture of the transformer. Each dictionary should have the following keys:

    • repeat: int

      Number of times to repeat the block

    • num_heads: int

      Number of heads in the multi-head attention

    • q_stride: List[int]

      Stride for the query in each spatial dimension

    • self_attention: bool

      Whether to use global self-attention (True) or mask unit attention (False)

  • spatial_dims (int) – Number of spatial dimensions

  • num_patches (int, List[int]) – Number of patches along each dimension, in (Z)YX order. If int, the same number of patches is used in each dimension.

  • num_mask_units (int, List[int]) – Number of mask units along each dimension, in (Z)YX order. If int, the same number of mask units is used in each dimension.

  • patch_size (int, List[int]) – Size of each patch, in (Z)YX order. If int, the same patch size is used in each dimension.

  • emb_dim (int) – Embedding dimension

  • decoder_layer (int) – Number of layers in the decoder

  • decoder_head (int) – Number of heads in the decoder

  • decoder_dim (int) – Dimension of the decoder

  • mask_ratio (float) – Fraction of mask units to remove

  • use_crossmae (bool) – If True, use a CrossMAE-style decoder instead of the standard MAE decoder

  • context_pixels (List[int]) – Number of extra pixels around each patch to include in the convolutional embedding that projects patches to the encoder dimension.

  • input_channels (int) – Number of input channels

  • features_only (bool) – If True, use only the encoder to extract features
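As a hedged illustration of the parameters above, the sketch below builds an example architecture list with the four documented keys and shows how an int value for num_patches or num_mask_units can be broadcast to one value per spatial dimension. It is plain Python and does not import cyto_dl. The helper names (to_per_dim, check_architecture, patches_per_mask_unit, masked_unit_count), the stage values, the requirement that num_patches be divisible by num_mask_units, and the MAE-style int() rounding of the kept count are all assumptions, not the library's implementation.

```python
from math import prod
from typing import Dict, List, Sequence, Union

# Illustrative architecture spec: one dict per stage, using the documented keys
# (repeat, num_heads, q_stride, self_attention). Values are made up, not defaults.
ARCHITECTURE: List[Dict] = [
    {"repeat": 1, "num_heads": 2, "q_stride": [1, 2, 2], "self_attention": False},
    {"repeat": 2, "num_heads": 4, "q_stride": [1, 2, 2], "self_attention": True},
]

REQUIRED_KEYS = {"repeat", "num_heads", "q_stride", "self_attention"}


def to_per_dim(value: Union[int, Sequence[int]], spatial_dims: int) -> List[int]:
    """Broadcast an int to one value per spatial dimension ((Z)YX order)."""
    if isinstance(value, int):
        return [value] * spatial_dims
    values = list(value)
    if len(values) != spatial_dims:
        raise ValueError(f"expected {spatial_dims} values, got {len(values)}")
    return values


def check_architecture(architecture: List[Dict], spatial_dims: int) -> None:
    """Hypothetical sanity check: every stage has the documented keys."""
    for i, stage in enumerate(architecture):
        missing = REQUIRED_KEYS - set(stage)
        if missing:
            raise KeyError(f"stage {i} is missing keys: {sorted(missing)}")
        if len(stage["q_stride"]) != spatial_dims:
            raise ValueError(f"stage {i}: q_stride needs {spatial_dims} entries")


def patches_per_mask_unit(num_patches, num_mask_units, spatial_dims: int) -> List[int]:
    """Patches per mask unit along each axis (ASSUMES exact divisibility)."""
    patches = to_per_dim(num_patches, spatial_dims)
    units = to_per_dim(num_mask_units, spatial_dims)
    for p, u in zip(patches, units):
        if p % u != 0:
            raise ValueError(f"num_patches {p} is not divisible by num_mask_units {u}")
    return [p // u for p, u in zip(patches, units)]


def masked_unit_count(num_mask_units, spatial_dims: int, mask_ratio: float) -> int:
    """Mask units removed, ASSUMING MAE-style int(total * (1 - mask_ratio)) kept."""
    total = prod(to_per_dim(num_mask_units, spatial_dims))
    num_keep = int(total * (1 - mask_ratio))
    return total - num_keep
```

With the defaults shown in the signature (num_patches=16, num_mask_units=8, spatial_dims=3, mask_ratio=0.75), this sketch gives 2 patches per mask unit along each axis and masks 384 of the 512 mask units.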

property decoder
property encoder
class cyto_dl.nn.vits.mae.MAE(spatial_dims: int = 3, num_patches: List[int] | None = 16, patch_size: List[int] | None = 4, emb_dim: int | None = 768, encoder_layer: int | None = 12, encoder_head: int | None = 8, decoder_layer: int | None = 4, decoder_head: int | None = 8, decoder_dim: int | None = 192, mask_ratio: int | None = 0.75, use_crossmae: bool | None = False, context_pixels: List[int] | None = 0, input_channels: int | None = 1, features_only: bool | None = False, learnable_pos_embedding: bool | None = True)

Bases: MAE_Base

Parameters:
  • spatial_dims (int) – Number of spatial dimensions

  • num_patches (List[int]) – Number of patches in each dimension (ZYX order)

  • patch_size (List[int]) – Size of each patch (ZYX order)

  • emb_dim (int) – Dimension of the encoder embedding

  • encoder_layer (int) – Number of encoder transformer layers

  • encoder_head (int) – Number of encoder heads

  • decoder_layer (int) – Number of decoder transformer layers

  • decoder_head (int) – Number of decoder heads

  • decoder_dim (int) – Dimension of the decoder

  • mask_ratio (float) – Fraction of patches to mask out

  • use_crossmae (bool) – If True, use a CrossMAE-style decoder instead of the standard MAE decoder

  • context_pixels (List[int]) – Number of extra pixels around each patch to include in the convolutional embedding that projects patches to the encoder dimension.

  • input_channels (int) – Number of input channels

  • features_only (bool) – If True, use only the encoder to extract features

  • learnable_pos_embedding (bool) – If True, learnable positional embeddings are used; if False, fixed sine/cosine positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
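When learnable_pos_embedding is False, fixed sine/cosine embeddings are used instead of learned ones. The sketch below shows one common way to build such embeddings for a (Z)YX patch grid, following the MAE-style recipe: the channels are split evenly across axes and, within each axis, the first half are sines and the second half cosines. It is a pure-Python illustration with hypothetical helper names (sincos_1d, sincos_grid); the exact layout cyto_dl uses (channel split, frequency base of 10000, position ordering, any extra class-token row) is an assumption, not taken from this documentation.

```python
import math
from itertools import product
from typing import List, Sequence


def sincos_1d(positions: Sequence[float], dim: int) -> List[List[float]]:
    """1D embedding: first half sines, second half cosines (MAE-style layout)."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    # Frequencies omega_i = 1 / 10000^(i / half), for i = 0 .. half - 1
    omegas = [1.0 / 10000 ** (i / half) for i in range(half)]
    return [
        [math.sin(p * w) for w in omegas] + [math.cos(p * w) for w in omegas]
        for p in positions
    ]


def sincos_grid(grid_shape: Sequence[int], dim: int) -> List[List[float]]:
    """Fixed N-D embedding for a patch grid in (Z)YX order.

    Each axis gets dim // len(grid_shape) channels; per-axis embeddings are
    concatenated for every grid position, visited in row-major order.
    """
    n_axes = len(grid_shape)
    if dim % (2 * n_axes) != 0:
        raise ValueError("dim must be divisible by 2 * number of axes")
    axis_dim = dim // n_axes
    per_axis = [sincos_1d(range(n), axis_dim) for n in grid_shape]
    embeddings = []
    for index in product(*(range(n) for n in grid_shape)):
        row: List[float] = []
        for axis, i in enumerate(index):
            row.extend(per_axis[axis][i])
        embeddings.append(row)
    return embeddings
```

Because these values are computed rather than learned, they add no trainable parameters and stay fixed throughout training.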

property decoder
property encoder
class cyto_dl.nn.vits.mae.MAE_Base(spatial_dims, num_patches, patch_size, mask_ratio, features_only, context_pixels)

Bases: Module, ABC

abstract property decoder
abstract property encoder
forward(img)
init_decoder()
init_encoder()
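This reference does not describe what forward(img) does internally. As a hedged illustration of the random masking that a mask_ratio argument (one of the MAE_Base constructor arguments) controls in the original MAE recipe, the sketch below shuffles token indices and keeps int(num_tokens * (1 - mask_ratio)) of them visible. The function name random_masking and the choice to return sorted index lists are assumptions for illustration; cyto_dl works on tensors and its implementation may differ. For HieraMAE, the masked tokens would be mask units rather than individual patches, per its mask_ratio description.

```python
import random
from typing import List, Optional, Tuple


def random_masking(
    num_tokens: int, mask_ratio: float, seed: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split token indices into (visible, masked).

    Keeps int(num_tokens * (1 - mask_ratio)) randomly chosen tokens,
    as in the original MAE recipe; the rest are masked out.
    """
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    rng = random.Random(seed)  # seeded generator for reproducible masks
    order = list(range(num_tokens))
    rng.shuffle(order)
    num_keep = int(num_tokens * (1 - mask_ratio))
    # Sorting is a presentation choice here; MAE keeps the shuffled order.
    visible = sorted(order[:num_keep])
    masked = sorted(order[num_keep:])
    return visible, masked
```

Only the visible tokens would be passed through the encoder; the decoder then reconstructs the masked ones.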