cyto_dl.nn.vits.seg module#

class cyto_dl.nn.vits.seg.EncodedSkip(spatial_dims, num_patches, emb_dim, n_decoder_filters, layer)[source]#

Bases: Module

forward(features)[source]#
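The `SuperresDecoder` docstring below describes each skip connection as "a different weighted sum of intermediate layer features", which is presumably what `EncodedSkip` computes. A minimal numpy sketch of that idea (the softmax weighting and all shapes here are illustrative assumptions, not the library's actual implementation):

```python
import numpy as np

n_layers, batch, n_tokens, emb_dim = 4, 1, 8, 16
# Features from several intermediate ViT layers, stacked along axis 0
features = np.random.randn(n_layers, batch, n_tokens, emb_dim)

# Per-layer mixing weights (learned in the real module; uniform here)
logits = np.zeros(n_layers)
weights = np.exp(logits) / np.exp(logits).sum()

# Weighted sum across layers -> one skip-connection tensor
skip = np.tensordot(weights, features, axes=(0, 0))
print(skip.shape)  # (1, 8, 16)
```

Each decoder level would hold its own set of weights, so different levels emphasize different encoder depths.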
class cyto_dl.nn.vits.seg.Seg_ViT(spatial_dims: int = 3, num_patches: List[int] | None = [2, 32, 32], base_patch_size: List[int] | None = [16, 16, 16], emb_dim: int | None = 768, decoder_layer: int | None = 3, n_decoder_filters: int | None = 16, out_channels: int | None = 6, upsample_factor: List[int] | None = [2.6134, 2.5005, 2.5005], encoder_ckpt: str | None = None, freeze_encoder: bool | None = True, **encoder_kwargs)[source]#

Bases: Module

Class for training a simple convolutional decoder on top of a pretrained ViT backbone.

Parameters:
  • spatial_dims (Optional[int] = 3) – Number of spatial dimensions

  • num_patches (Optional[List[int]] = [2, 32, 32]) – Number of patches in each dimension (ZYX order)

  • base_patch_size (Optional[List[int]] = [16, 16, 16]) – Base patch size in each dimension (ZYX order)

  • emb_dim (Optional[int] = 768) – Embedding dimension of ViT backbone

  • encoder_layer (Optional[int] = 12) – Number of layers in ViT backbone

  • encoder_head (Optional[int] = 8) – Number of heads in ViT backbone

  • decoder_layer (Optional[int] = 3) – Number of layers in convolutional decoder

  • n_decoder_filters (Optional[int] = 16) – Number of filters in convolutional decoder

  • out_channels (Optional[int] = 6) – Number of output channels in convolutional decoder. Should be 6 for instance segmentation.

  • mask_ratio (Optional[float] = 0.75) – Ratio of patches to be masked out during training

  • upsample_factor (Optional[List[int]] = [2.6134, 2.5005, 2.5005]) – Upsampling factor for each dimension (ZYX order). Default is AICS 20x to 100x objective upsampling

  • encoder_ckpt (Optional[str] = None) – Path to pretrained ViT backbone checkpoint

  • freeze_encoder (Optional[bool] = True) – Whether to freeze the weights of the pretrained encoder

forward(img)[source]#
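With the defaults above, the expected input size per dimension is `num_patches * base_patch_size`, and the decoder output is scaled by `upsample_factor`. A quick shape calculation (assuming patches tile the input exactly; the library's exact rounding of the upsampled size is not shown here):

```python
import numpy as np

num_patches = np.array([2, 32, 32])        # Z, Y, X
base_patch_size = np.array([16, 16, 16])   # Z, Y, X
upsample_factor = np.array([2.6134, 2.5005, 2.5005])

# Input size the ViT backbone expects: one base-sized patch per grid cell
input_size = num_patches * base_patch_size
print(input_size.tolist())  # [32, 512, 512]

# Approximate spatial size after the superresolution decoder
output_size = input_size * upsample_factor
print(output_size.tolist())  # ~[83.6, 1280.3, 1280.3]
```

So the default configuration maps a 32×512×512 (ZYX) 20x image to roughly 100x-scale output dimensions.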
class cyto_dl.nn.vits.seg.SuperresDecoder(spatial_dims: int = 3, num_patches: List[int] | None = [2, 32, 32], base_patch_size: List[int] | None = [4, 8, 8], emb_dim: int | None = 192, n_decoder_filters: int | None = 16, out_channels: int | None = 6, upsample_factor: int | List[int] | None = [2.6134, 2.5005, 2.5005], num_layer: int | None = 3)[source]#

Bases: Module

Create a UNet-like decoder where each decoder layer is fed a skip connection consisting of a different weighted sum of intermediate layer features.

forward(features: Tensor)[source]#
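Before any convolution can run, the decoder must turn the ViT's token sequence back into a spatial grid. A schematic numpy sketch of that rearrangement using the `SuperresDecoder` defaults (`num_patches=[2, 32, 32]`, `emb_dim=192`); this illustrates the shape bookkeeping only, not the library's code:

```python
import numpy as np

batch, emb_dim = 1, 192
num_patches = (2, 32, 32)  # Z, Y, X grid of patch tokens
n_tokens = int(np.prod(num_patches))  # 2048 tokens

# ViT output: one embedding vector per patch token
tokens = np.random.randn(batch, n_tokens, emb_dim)

# Rearrange tokens into a channels-first spatial grid so a 3D conv
# decoder can process them: (B, n_tokens, C) -> (B, C, Z, Y, X)
grid = tokens.transpose(0, 2, 1).reshape(batch, emb_dim, *num_patches)
print(grid.shape)  # (1, 192, 2, 32, 32)
```

From this grid, successive decoder layers upsample back toward (and beyond) the original voxel resolution, with the `EncodedSkip` weighted sums merged in at each level.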