cyto_dl.nn.vits.encoder module#

class cyto_dl.nn.vits.encoder.HieraEncoder(num_patches: int | List[int], num_mask_units: int | List[int], architecture: List[Dict], emb_dim: int = 64, spatial_dims: int = 3, patch_size: int | List[int] = 4, context_pixels: int | List[int] | None = 0, input_channels: int | None = 1, save_layers: bool | None = False)[source]#

Bases: Module

Parameters:
  • num_patches (int, List[int]) – Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same.

  • num_mask_units (int, List[int]) – Number of mask units in each dimension. If a single int is provided, the number of mask units in each dimension will be the same.

  • architecture (List[Dict]) – List of dictionaries specifying the architecture of the transformer. Each dictionary should have the following keys:

    • repeat: int – Number of times to repeat the block

    • num_heads: int – Number of heads in the multihead attention

    • q_stride: int, List[int] – Stride for the query in each spatial dimension

    • self_attention: bool – Whether to use self attention or mask unit attention

    On the last repeat of each non-self-attention block, the embedding dimension is doubled and spatial pooling with q_stride is performed within each mask unit. For example, for a block with embed_dim=4, q_stride=2, and repeat=2, the first repeat only performs mask unit attention, while the second produces an 8-dimensional output that has been spatially pooled (see the usage sketch below).

  • emb_dim (int) – Dimension of embedding

  • spatial_dims (int) – Number of spatial dimensions

  • patch_size (List[int]) – Size of each patch

  • context_pixels (List[int]) – Number of extra pixels around each patch to include in the convolutional embedding to the encoder dimension.

  • input_channels (int) – Number of input channels

  • save_layers (bool) – Whether to save the intermediate layer outputs

forward(img, mask_ratio)[source]#
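
A minimal usage sketch. The values below are illustrative, not taken from the documentation above: the image side lengths are assumed to be num_patches × patch_size per spatial dimension, and the q_stride schedule is chosen so that each mask unit is pooled down to a single token before the self-attention block.

    import torch
    from cyto_dl.nn.vits.encoder import HieraEncoder

    encoder = HieraEncoder(
        num_patches=[8, 16, 16],    # patches per (Z, Y, X) dimension
        num_mask_units=[2, 4, 4],   # mask units per (Z, Y, X) dimension -> 4x4x4 patches per mask unit
        architecture=[
            # mask unit attention; on the last repeat the embedding dimension is
            # doubled and q_stride pooling is applied within each mask unit
            {"repeat": 2, "num_heads": 2, "q_stride": 2, "self_attention": False},
            {"repeat": 2, "num_heads": 4, "q_stride": 2, "self_attention": False},
            # global self attention over the pooled tokens
            {"repeat": 2, "num_heads": 8, "self_attention": True},
        ],
        emb_dim=64,
        spatial_dims=3,
        patch_size=4,
        input_channels=1,
    )

    img = torch.randn(1, 1, 32, 64, 64)       # (batch, channels, Z, Y, X), assumed size
    features = encoder(img, mask_ratio=0.75)  # mask 75% of mask units during pre-training
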
class cyto_dl.nn.vits.encoder.JEPAEncoder(num_patches: int | List[int], spatial_dims: int = 3, patch_size: int | List[int] = 4, emb_dim: int | None = 192, num_layer: int | None = 12, num_head: int | None = 3, context_pixels: int | List[int] | None = 0, input_channels: int | None = 1, learnable_pos_embedding: bool | None = True)[source]#

Bases: Module

Parameters:
  • num_patches (List[int], int) – Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same.

  • spatial_dims (int) – Number of spatial dimensions

  • patch_size (List[int]) – Size of each patch

  • emb_dim (int) – Dimension of embedding

  • num_layer (int) – Number of transformer layers

  • num_head (int) – Number of heads in transformer

  • context_pixels (List[int], int) – Number of extra pixels around each patch to include in the convolutional embedding to the encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same.

  • input_channels (int) – Number of input channels

  • learnable_pos_embedding (bool) – If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.

forward(img, patchify=True)[source]#
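
A minimal usage sketch. The input shape is an assumption (side lengths taken to be num_patches × patch_size per spatial dimension), and the interpretation of patchify as "split the raw image into patch tokens before the transformer" is likewise assumed rather than documented above.

    import torch
    from cyto_dl.nn.vits.encoder import JEPAEncoder

    encoder = JEPAEncoder(
        num_patches=[8, 16, 16],   # patches per (Z, Y, X) dimension
        spatial_dims=3,
        patch_size=4,
        emb_dim=192,
        num_layer=12,
        num_head=3,
        learnable_pos_embedding=False,  # fixed sin/cos embeddings, e.g. for brightfield images
    )

    img = torch.randn(1, 1, 32, 64, 64)  # (batch, channels, Z, Y, X), assumed size
    tokens = encoder(img)                # patchify=True (default): embed the raw image into patch tokens
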
class cyto_dl.nn.vits.encoder.MAE_Encoder(num_patches: List[int], spatial_dims: int = 3, patch_size: int | List[int] = 4, emb_dim: int | None = 192, num_layer: int | None = 12, num_head: int | None = 3, context_pixels: int | List[int] | None = 0, input_channels: int | None = 1, n_intermediate_weights: int | None = -1)[source]#

Bases: Module

Parameters:
  • num_patches (List[int], int) – Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same.

  • spatial_dims (int) – Number of spatial dimensions

  • patch_size (List[int]) – Size of each patch

  • emb_dim (int) – Dimension of embedding

  • num_layer (int) – Number of transformer layers

  • num_head (int) – Number of heads in transformer

  • context_pixels (List[int], int) – Number of extra pixels around each patch to include in the convolutional embedding to the encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same.

  • input_channels (int) – Number of input channels

  • n_intermediate_weights (int) – If positive, a learned weighted sum of intermediate layer outputs is used; the default of -1 disables this.

forward(img, mask_ratio=0.75)[source]#
init_weight()[source]#
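
A minimal usage sketch. The input shape is an assumption (side lengths taken to be num_patches × patch_size per spatial dimension), and the exact return format of forward is not documented here, so the output is left unpacked.

    import torch
    from cyto_dl.nn.vits.encoder import MAE_Encoder

    encoder = MAE_Encoder(
        num_patches=[8, 16, 16],   # patches per (Z, Y, X) dimension
        spatial_dims=3,
        patch_size=4,
        emb_dim=192,
        num_layer=12,
        num_head=3,
    )

    img = torch.randn(1, 1, 32, 64, 64)      # (batch, channels, Z, Y, X), assumed size
    # mask_ratio=0.75 masks 75% of patches before encoding, as in standard MAE pre-training
    encoded = encoder(img, mask_ratio=0.75)
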
class cyto_dl.nn.vits.encoder.SpatialMerger(downsample_factor: List[int], in_dim: int, out_dim: int, spatial_dims: int = 3)[source]#

Bases: Module

Class for converting multi-resolution Hiera features to the same (lowest) spatial resolution via convolution.

forward(x)[source]#
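
A minimal construction sketch. The parameter values are illustrative and the expected input layout of forward is not documented here, so only instantiation is shown.

    from cyto_dl.nn.vits.encoder import SpatialMerger

    # Pool higher-resolution Hiera features by a factor of 2 in each spatial dimension
    # and project them from 64 to 128 channels so they match the lowest-resolution stage.
    merger = SpatialMerger(downsample_factor=[2, 2, 2], in_dim=64, out_dim=128, spatial_dims=3)
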