cyto_dl.nn.vits.decoder module
- class cyto_dl.nn.vits.decoder.CrossMAE_Decoder(num_patches: int | List[int], spatial_dims: int = 3, patch_size: int | List[int] | None = 4, enc_dim: int | None = 768, emb_dim: int | None = 192, num_layer: int | None = 4, num_head: int | None = 3, has_cls_token: bool | None = True, learnable_pos_embedding: bool | None = True)
Bases: MAE_Decoder
Decoder inspired by [CrossMAE](https://crossmae.github.io/), in which masked tokens attend only to visible tokens.
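The core idea can be illustrated with a toy sketch of masked-to-visible cross-attention: the masked tokens act as queries, and the visible (encoded) tokens supply the keys and values, so masked tokens never attend to one another. This is a simplified, hypothetical illustration (a single head, no learned query/key/value projections, no MLP or normalization layers, plain Python lists instead of tensors). The helpers `softmax` and `cross_attend` are not part of cyto_dl, and the module's actual internals may differ.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attend(queries: List[Vector], keys: List[Vector], values: List[Vector]) -> List[Vector]:
    """Hypothetical single-head scaled dot-product attention (no learned projections).

    Each query is scored only against `keys`, so when the queries are masked
    tokens and the keys/values are visible tokens, masked tokens never see
    each other.
    """
    d = len(keys[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        # Weighted average of the visible-token values.
        out.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))])
    return out


# Stand-ins for encoder outputs (visible tokens) and for masked-token queries
# (in a real decoder: a mask token plus its positional embedding).
visible = [[1.0, 0.0], [0.0, 1.0]]
masked = [[0.5, 0.5], [2.0, -1.0], [0.0, 0.0]]

reconstructed = cross_attend(masked, keys=visible, values=visible)
print(len(reconstructed))  # 3: one output per masked token
```

Because keys and values come only from the visible tokens, the attention step in this sketch costs on the order of (number of masked tokens × number of visible tokens), rather than the square of the full token count that self-attention over all tokens would require.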
- Parameters:
num_patches (List[int], int) – Number of patches along each spatial dimension. If int, the same number of patches is used for all dimensions.
spatial_dims (int) – Number of spatial dimensions of the input. Defaults to 3.
patch_size (List[int], int) – Size of each patch along each spatial dimension. If int, the same patch size is used for all dimensions.
enc_dim (int) – Feature dimension of the encoder output
emb_dim (int) – Embedding dimension of the decoder
num_layer (int) – Number of transformer layers in the decoder
num_head (int) – Number of attention heads in each transformer layer
has_cls_token (bool) – Whether the encoder features include a cls token
learnable_pos_embedding (bool) – If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
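Several parameters accept either a single int or a per-axis list. The sketch below shows one way such values could be broadcast across `spatial_dims` axes, and how the resulting patch grid relates to the token count and the input size. The helpers `to_per_axis` and `describe_grid` are hypothetical, and the relation "input size = num_patches × patch_size" on each axis assumes non-overlapping patches; cyto_dl's own handling may differ.

```python
import math
from typing import List, Tuple, Union

IntOrList = Union[int, List[int]]


def to_per_axis(value: IntOrList, spatial_dims: int) -> List[int]:
    """Broadcast an int to every spatial axis, or check a per-axis list."""
    if isinstance(value, int):
        return [value] * spatial_dims
    values = list(value)
    if len(values) != spatial_dims:
        raise ValueError(f"expected {spatial_dims} values, got {len(values)}")
    return values


def describe_grid(
    num_patches: IntOrList, patch_size: IntOrList, spatial_dims: int = 3
) -> Tuple[List[int], List[int], int, List[int]]:
    """Return (patch grid, patch size, token count, implied input shape).

    Assumes non-overlapping patches, so each input axis spans
    num_patches * patch_size pixels/voxels. The token count excludes any cls token.
    """
    grid = to_per_axis(num_patches, spatial_dims)
    patch = to_per_axis(patch_size, spatial_dims)
    n_tokens = math.prod(grid)
    input_shape = [n * p for n, p in zip(grid, patch)]
    return grid, patch, n_tokens, input_shape


print(describe_grid(8, 4))  # ([8, 8, 8], [4, 4, 4], 512, [32, 32, 32])
print(describe_grid([2, 16, 16], [4, 8, 8]))  # ([2, 16, 16], [4, 8, 8], 512, [8, 128, 128])
```

Passing per-axis lists is useful for anisotropic 3D data, where the Z axis is often much shorter than Y and X.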
- class cyto_dl.nn.vits.decoder.MAE_Decoder(num_patches: int | List[int], spatial_dims: int = 3, patch_size: int | List[int] | None = 4, enc_dim: int | None = 768, emb_dim: int | None = 192, num_layer: int | None = 4, num_head: int | None = 3, has_cls_token: bool | None = False, learnable_pos_embedding: bool | None = True)
Bases: Module
- Parameters:
num_patches (List[int], int) – Number of patches along each spatial dimension. If int, the same number of patches is used for all dimensions.
spatial_dims (int) – Number of spatial dimensions of the input. Defaults to 3.
patch_size (List[int], int) – Size of each patch along each spatial dimension. If int, the same patch size is used for all dimensions.
enc_dim (int) – Feature dimension of the encoder output
emb_dim (int) – Embedding dimension of the decoder
num_layer (int) – Number of transformer layers in the decoder
num_head (int) – Number of attention heads in each transformer layer
has_cls_token (bool) – Whether the encoder features include a cls token
learnable_pos_embedding (bool) – If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
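With `learnable_pos_embedding=False`, positional embeddings are fixed sin/cos functions of patch position instead of trained parameters. The sketch below builds such a table for an N-D patch grid using the standard Transformer sinusoid, splitting the embedding dimension evenly across axes and concatenating the per-axis embeddings (a common way to extend the 1-D formula to grids). This is an illustrative assumption, not cyto_dl's implementation: the helper names, the base of 10000, the sin-then-cos layout, and the requirement that `emb_dim` be divisible by `2 * spatial_dims` are all choices made for this example.

```python
import itertools
import math
from typing import List, Sequence


def sincos_1d(num_positions: int, dim: int) -> List[List[float]]:
    """1-D sinusoidal embedding: first half sin, second half cos."""
    assert dim % 2 == 0, "dim must be even"
    half = dim // 2
    freqs = [1.0 / (10000 ** (i / half)) for i in range(half)]
    table = []
    for pos in range(num_positions):
        table.append([math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs])
    return table


def sincos_nd(grid: Sequence[int], emb_dim: int) -> List[List[float]]:
    """Fixed embedding for every patch of an N-D grid, in row-major order.

    emb_dim is split evenly across the axes; each axis gets its own 1-D
    sin/cos embedding, and the pieces are concatenated.
    """
    n_axes = len(grid)
    assert emb_dim % (2 * n_axes) == 0, "emb_dim must be divisible by 2 * spatial_dims"
    axis_dim = emb_dim // n_axes
    per_axis = [sincos_1d(n, axis_dim) for n in grid]
    table = []
    for coords in itertools.product(*(range(n) for n in grid)):
        row: List[float] = []
        for axis, c in enumerate(coords):
            row.extend(per_axis[axis][c])
        table.append(row)
    return table


# A 2 x 4 x 4 patch grid with the default decoder width emb_dim=192.
pos_embed = sincos_nd([2, 4, 4], emb_dim=192)
print(len(pos_embed), len(pos_embed[0]))  # 32 192
```

Because the table depends only on the grid shape, a decoder could compute it once and store it as a non-trainable buffer (for example with `torch.nn.Module.register_buffer`) rather than learning it.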