cyto_dl.nn.vits.encoder module#
- class cyto_dl.nn.vits.encoder.HieraEncoder(num_patches: int | List[int], num_mask_units: int | List[int], architecture: List[Dict], emb_dim: int = 64, spatial_dims: int = 3, patch_size: int | List[int] = 4, context_pixels: int | List[int] | None = 0, input_channels: int | None = 1, save_layers: bool | None = False)[source]#
Bases:
Module
- Parameters:
num_patches (int, List[int]) – Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same.
num_mask_units (int, List[int]) – Number of mask units in each dimension. If a single int is provided, the number of mask units in each dimension will be the same.
architecture (List[Dict]) – List of dictionaries specifying the architecture of the transformer. Each dictionary should have the following keys:
  - repeat (int): Number of times to repeat the block
  - num_heads (int): Number of heads in the multihead attention
  - q_stride (int, List[int]): Stride for the query in each spatial dimension
  - self_attention (bool): Whether to use self attention or mask unit attention
On the last repeat of each non-self-attention block, the embedding dimension is doubled and spatial pooling with q_stride is performed within each mask unit. For example, in a block with embed_dim=4, q_stride=2, and repeat=2, the first repeat does only mask unit attention, while the second produces an 8-dimensional output that has been spatially pooled.
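The doubling-and-pooling behavior described above can be traced in plain Python. This is an illustrative sketch only, not cyto_dl's implementation; the dict keys follow the architecture specification above, and `trace_hiera_dims` and `mask_unit_size` are hypothetical names:

```python
def trace_hiera_dims(emb_dim, mask_unit_size, architecture):
    """Trace embedding dim and per-mask-unit spatial size across repeats.

    Illustration only: on the last repeat of each non-self-attention
    block, emb_dim doubles and the mask unit is pooled by q_stride.
    """
    dims = []
    for block in architecture:
        for rep in range(block["repeat"]):
            is_last = rep == block["repeat"] - 1
            if is_last and not block.get("self_attention", False):
                emb_dim *= 2
                mask_unit_size = [s // block["q_stride"] for s in mask_unit_size]
            dims.append((emb_dim, list(mask_unit_size)))
    return dims

# The docstring's example: embed_dim=4, q_stride=2, repeat=2.
# First repeat keeps 4 dims; second doubles to 8 and pools 8 -> 4.
arch = [{"repeat": 2, "num_heads": 2, "q_stride": 2, "self_attention": False}]
print(trace_hiera_dims(4, [8, 8, 8], arch))
# [(4, [8, 8, 8]), (8, [4, 4, 4])]
```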
emb_dim (int) – Dimension of embedding
spatial_dims (int) – Number of spatial dimensions
patch_size (int, List[int]) – Size of each patch in each spatial dimension. If a single int is provided, the same patch size is used in each dimension.
context_pixels (int, List[int]) – Number of extra pixels around each patch to include in the convolutional embedding to the encoder dimension.
input_channels (int) – Number of input channels
save_layers (bool) – Whether to save the intermediate layer outputs
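Several parameters above (num_patches, num_mask_units, patch_size, context_pixels) accept either a single int or a per-dimension list. A minimal sketch of that broadcasting behavior, using a hypothetical helper that is not part of cyto_dl's API:

```python
from typing import List, Union

def to_per_dim(value: Union[int, List[int]], spatial_dims: int) -> List[int]:
    """Broadcast a scalar parameter to one value per spatial dimension."""
    if isinstance(value, int):
        return [value] * spatial_dims
    assert len(value) == spatial_dims, "length must match spatial_dims"
    return list(value)

# With spatial_dims=3, num_patches=4 means 4 patches in each dimension,
# while an explicit list gives per-dimension control (e.g. anisotropic Z).
print(to_per_dim(4, 3))            # [4, 4, 4]
print(to_per_dim([8, 16, 16], 3))  # [8, 16, 16]
```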
- class cyto_dl.nn.vits.encoder.JEPAEncoder(num_patches: int | List[int], spatial_dims: int = 3, patch_size: int | List[int] = 4, emb_dim: int | None = 192, num_layer: int | None = 12, num_head: int | None = 3, context_pixels: int | List[int] | None = 0, input_channels: int | None = 1, learnable_pos_embedding: bool | None = True)[source]#
Bases:
Module
- Parameters:
num_patches (List[int], int) – Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same.
spatial_dims (int) – Number of spatial dimensions
patch_size (int, List[int]) – Size of each patch in each spatial dimension. If a single int is provided, the same patch size is used in each dimension.
emb_dim (int) – Dimension of embedding
num_layer (int) – Number of transformer layers
num_head (int) – Number of heads in transformer
context_pixels (int, List[int]) – Number of extra pixels around each patch to include in the convolutional embedding to the encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same.
input_channels (int) – Number of input channels
learnable_pos_embedding (bool) – If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images.
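The fixed sin/cos option refers to the standard transformer positional encoding. A minimal 1D sketch, assuming the usual 10000-base formulation rather than cyto_dl's exact implementation (`sincos_pos_embedding` is a hypothetical name):

```python
import math

def sincos_pos_embedding(n_positions: int, emb_dim: int):
    """Fixed sin/cos positional embeddings (emb_dim must be even).

    Each position gets interleaved sin/cos values at geometrically
    spaced frequencies; no parameters are learned.
    """
    table = []
    for pos in range(n_positions):
        row = []
        for i in range(emb_dim // 2):
            angle = pos / (10000 ** (2 * i / emb_dim))
            row.extend([math.sin(angle), math.cos(angle)])
        table.append(row)
    return table

emb = sincos_pos_embedding(n_positions=16, emb_dim=8)
# Position 0 embeds as [sin(0), cos(0), ...] = [0, 1, 0, 1, 0, 1, 0, 1]
```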
- class cyto_dl.nn.vits.encoder.MAE_Encoder(num_patches: List[int], spatial_dims: int = 3, patch_size: int | List[int] = 4, emb_dim: int | None = 192, num_layer: int | None = 12, num_head: int | None = 3, context_pixels: int | List[int] | None = 0, input_channels: int | None = 1, n_intermediate_weights: int | None = -1)[source]#
Bases:
Module
- Parameters:
num_patches (List[int], int) – Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same.
spatial_dims (int) – Number of spatial dimensions
patch_size (int, List[int]) – Size of each patch in each spatial dimension. If a single int is provided, the same patch size is used in each dimension.
emb_dim (int) – Dimension of embedding
num_layer (int) – Number of transformer layers
num_head (int) – Number of heads in transformer
context_pixels (int, List[int]) – Number of extra pixels around each patch to include in the convolutional embedding to the encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same.
input_channels (int) – Number of input channels
n_intermediate_weights (int) – Number of intermediate layer outputs to combine via a weighted sum; -1 disables intermediate weighting.
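A sketch of the weighted-sum idea: combine several intermediate layer outputs using softmax-normalized weights (in the encoder these would be learnable). This is an assumption-laden illustration in plain Python, not cyto_dl's code; `weighted_layer_sum` is a hypothetical name:

```python
import math

def weighted_layer_sum(layer_outputs, weights):
    """Combine intermediate layer outputs with softmax-normalized weights.

    layer_outputs: list of n equal-length feature vectors (one per layer)
    weights: n raw scalars (learnable in a real model), softmax-normalized here
    """
    exps = [math.exp(w) for w in weights]
    total = sum(exps)
    norm = [e / total for e in exps]
    dim = len(layer_outputs[0])
    return [sum(norm[i] * layer_outputs[i][j] for i in range(len(weights)))
            for j in range(dim)]

# Two intermediate layers with equal raw weights -> elementwise average.
out = weighted_layer_sum([[1.0, 2.0], [3.0, 4.0]], [0.0, 0.0])
# out == [2.0, 3.0]
```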