cyto_dl.nn.vits.decoder module#

class cyto_dl.nn.vits.decoder.CrossMAE_Decoder(num_patches: int | List[int], spatial_dims: int = 3, patch_size: int | List[int] | None = 4, enc_dim: int | None = 768, emb_dim: int | None = 192, num_layer: int | None = 4, num_head: int | None = 3, has_cls_token: bool | None = True, learnable_pos_embedding: bool | None = True)[source]#

Bases: MAE_Decoder

Decoder inspired by [CrossMAE](https://crossmae.github.io/), in which masked tokens attend only to visible tokens.
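To make that attention pattern concrete, here is a minimal, hypothetical sketch (not this class's internals) of masked-query cross-attention using `torch.nn.MultiheadAttention`; all shapes and token counts are illustrative:

```python
import torch
import torch.nn as nn

emb_dim, num_head = 192, 3
cross_attn = nn.MultiheadAttention(emb_dim, num_head)

# Token-first layout: [tokens, batch, channels].
visible = torch.randn(64, 2, emb_dim)   # projected features of visible patches
masked = torch.randn(192, 2, emb_dim)   # mask tokens + positional embeddings

# Masked tokens are the queries; only visible tokens serve as keys/values,
# so masked tokens never attend to one another or to themselves.
decoded, _ = cross_attn(query=masked, key=visible, value=visible)
print(decoded.shape)  # torch.Size([192, 2, 192])
```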

Parameters:
  • num_patches (List[int], int) – Number of patches in each dimension. If int, the same number of patches is used for all dimensions.

  • patch_size (List[int], int) – Size of each patch in each dimension. If int, the same patch size is used for all dimensions.

  • enc_dim (int) – Dimension of the encoder embedding

  • emb_dim (int) – Dimension of the decoder embedding

  • num_layer (int) – Number of transformer layers

  • num_head (int) – Number of heads in transformer

  • has_cls_token (bool) – Whether encoder features have a cls token

  • learnable_pos_embedding (bool) – If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.

forward(features, forward_indexes, backward_indexes)[source]#
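A hypothetical usage sketch for a 3D input. The token-first [tokens, batch, channels] feature layout, the [n_patches, batch] index shapes, and the masking ratio are assumptions borrowed from common MAE implementations, not verified against cyto_dl:

```python
import torch
from cyto_dl.nn.vits.decoder import CrossMAE_Decoder

num_patches = [4, 8, 8]                    # patch grid for a 3D volume
decoder = CrossMAE_Decoder(
    num_patches=num_patches,
    spatial_dims=3,
    patch_size=4,
    enc_dim=768,
    emb_dim=192,
    num_layer=4,
    num_head=3,
    has_cls_token=True,
)

batch, n_total = 2, 4 * 8 * 8              # 256 patches total
n_visible = n_total // 4                   # e.g. 75% of patches masked

# Encoder output: cls token + visible-token features.
features = torch.randn(n_visible + 1, batch, 768)

# Random shuffle of patch positions and its inverse, one column per batch item.
forward_indexes = torch.stack([torch.randperm(n_total) for _ in range(batch)], dim=1)
backward_indexes = torch.argsort(forward_indexes, dim=0)

output = decoder(features, forward_indexes, backward_indexes)
```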
class cyto_dl.nn.vits.decoder.MAE_Decoder(num_patches: int | List[int], spatial_dims: int = 3, patch_size: int | List[int] | None = 4, enc_dim: int | None = 768, emb_dim: int | None = 192, num_layer: int | None = 4, num_head: int | None = 3, has_cls_token: bool | None = False, learnable_pos_embedding: bool | None = True)[source]#

Bases: Module

Parameters:
  • num_patches (List[int], int) – Number of patches in each dimension. If int, the same number of patches is used for all dimensions.

  • patch_size (List[int], int) – Size of each patch in each dimension. If int, the same patch size is used for all dimensions.

  • enc_dim (int) – Dimension of the encoder embedding

  • emb_dim (int) – Dimension of the decoder embedding

  • num_layer (int) – Number of transformer layers

  • num_head (int) – Number of heads in transformer

  • has_cls_token (bool) – Whether encoder features have a cls token

  • learnable_pos_embedding (bool) – If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.

add_mask_tokens(features, backward_indexes)[source]#
adjust_indices_for_cls(indexes)[source]#
forward(features, forward_indexes, backward_indexes)[source]#
init_weight()[source]#
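add_mask_tokens conceptually pads the visible-token features with learnable mask tokens and unshuffles the sequence back to the original patch order using backward_indexes. A minimal sketch of that operation under the same assumed token-first layout; the helper below is illustrative, not cyto_dl's API (cls-token handling, covered by adjust_indices_for_cls, is omitted):

```python
import torch

def unshuffle_with_mask_tokens(features, backward_indexes, mask_token):
    """Illustrative stand-in for add_mask_tokens (not the actual implementation).

    features:         [n_visible, B, C] visible-token features
    backward_indexes: [n_total, B] permutation restoring original patch order
    mask_token:       [1, 1, C] learnable placeholder for masked positions
    """
    n_total, B = backward_indexes.shape
    C = features.shape[-1]
    n_masked = n_total - features.shape[0]
    # Give every masked position a copy of the mask token...
    padded = torch.cat([features, mask_token.expand(n_masked, B, C)], dim=0)
    # ...then reorder all tokens back to the original patch layout.
    return torch.gather(padded, 0, backward_indexes[..., None].expand(-1, -1, C))

mask_token = torch.zeros(1, 1, 192)
features = torch.randn(64, 2, 192)                  # 64 visible tokens, batch of 2
backward_indexes = torch.stack(
    [torch.randperm(256) for _ in range(2)], dim=1  # 256 total patches
)
restored = unshuffle_with_mask_tokens(features, backward_indexes, mask_token)
print(restored.shape)  # torch.Size([256, 2, 192])
```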