cyto_dl.nn.vits.seg module
- class cyto_dl.nn.vits.seg.EncodedSkip(spatial_dims, num_patches, emb_dim, n_decoder_filters, layer)
Bases:
Module
- class cyto_dl.nn.vits.seg.Seg_ViT(spatial_dims: int = 3, num_patches: List[int] | None = [2, 32, 32], base_patch_size: List[int] | None = [16, 16, 16], emb_dim: int | None = 768, decoder_layer: int | None = 3, n_decoder_filters: int | None = 16, out_channels: int | None = 6, upsample_factor: List[int] | None = [2.6134, 2.5005, 2.5005], encoder_ckpt: str | None = None, freeze_encoder: bool | None = True, **encoder_kwargs)
Bases:
Module
Class for training a simple convolutional decoder on top of a pretrained ViT backbone.
- Parameters:
spatial_dims (Optional[int]=3) – Number of spatial dimensions
num_patches (Optional[List[int]]=[2, 32, 32]) – Number of patches along each dimension, in ZYX order
base_patch_size (Optional[List[int]]=[16, 16, 16]) – Base patch size along each dimension, in ZYX order
emb_dim (Optional[int]=768) – Embedding dimension of the ViT backbone
encoder_layer (Optional[int]=12) – Number of layers in the ViT backbone
encoder_head (Optional[int]=8) – Number of attention heads in the ViT backbone
decoder_layer (Optional[int]=3) – Number of layers in the convolutional decoder
n_decoder_filters (Optional[int]=16) – Number of filters in the convolutional decoder
out_channels (Optional[int]=6) – Number of output channels in the convolutional decoder. Should be 6 for instance segmentation.
mask_ratio (Optional[float]=0.75) – Ratio of patches to be masked out during training
upsample_factor (Optional[List[float]]=[2.6134, 2.5005, 2.5005]) – Upsampling factor along each dimension, in ZYX order. The default is the AICS 20x to 100x object upsampling.
encoder_ckpt (Optional[str]=None) – Path to pretrained ViT backbone checkpoint
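The default values above imply a particular input and output geometry. The following is a minimal sketch of that arithmetic in plain Python, not part of cyto_dl. It assumes that the input image is tiled exactly by num_patches patches of base_patch_size voxels each, and that the output size is the input size scaled by upsample_factor and rounded down. Both assumptions, and the helper names input_shape and upsampled_shape, are illustrative; the library may compute or round these sizes differently.

```python
import math


def input_shape(num_patches, base_patch_size):
    """Voxels per axis (ZYX) covered by the patch grid (assumed exact tiling)."""
    return [n * p for n, p in zip(num_patches, base_patch_size)]


def upsampled_shape(shape, upsample_factor):
    """Approximate output size per axis; flooring is an assumption."""
    return [math.floor(s * f) for s, f in zip(shape, upsample_factor)]


# Documented defaults of Seg_ViT, all in ZYX order.
num_patches = [2, 32, 32]
base_patch_size = [16, 16, 16]
upsample_factor = [2.6134, 2.5005, 2.5005]

in_shape = input_shape(num_patches, base_patch_size)
out_shape = upsampled_shape(in_shape, upsample_factor)
total_patches = math.prod(num_patches)

print("input shape (ZYX):", in_shape)
print("approximate output shape (ZYX):", out_shape)
print("tokens seen by the ViT backbone:", total_patches)
```

Under these assumptions the defaults describe a 32 x 512 x 512 input split into 2048 patches, so changing num_patches or base_patch_size changes the expected crop size as well as the token count.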
- class cyto_dl.nn.vits.seg.SuperresDecoder(spatial_dims: int = 3, num_patches: List[int] | None = [2, 32, 32], base_patch_size: List[int] | None = [4, 8, 8], emb_dim: int | None = 192, n_decoder_filters: int | None = 16, out_channels: int | None = 6, upsample_factor: int | List[int] | None = [2.6134, 2.5005, 2.5005], num_layer: int | None = 3)
Bases:
Module
Creates a UNet-like decoder in which each decoder layer is fed a skip connection consisting of a different weighted sum of intermediate layer features.
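The idea of a per-layer weighted sum can be illustrated with a small, dependency-free sketch. This is not the cyto_dl implementation: the function weighted_skip, the toy feature vectors, and the fixed weights are all hypothetical, and in a real model the weights would be learned parameters (possibly normalized) applied to tensors rather than lists.

```python
def weighted_skip(features, weights):
    """Combine per-layer feature vectors into a single skip vector.

    features: list of equal-length lists, one per intermediate encoder layer.
    weights:  one scalar per encoder layer (learned in a real model).
    """
    if len(features) != len(weights):
        raise ValueError("need one weight per intermediate layer")
    dim = len(features[0])
    return [sum(w * f[i] for w, f in zip(weights, features)) for i in range(dim)]


# Three hypothetical intermediate layers, each a 2-element feature vector.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Each decoder layer owns a different weight vector over the same features.
decoder_weights = [
    [1.0, 0.0, 0.0],    # decoder layer 0 uses only the first encoder layer
    [0.0, 0.5, 0.5],    # decoder layer 1 mixes the last two encoder layers
    [0.25, 0.25, 0.5],  # decoder layer 2 mixes all three
]

skips = [weighted_skip(features, w) for w in decoder_weights]
for layer, skip in enumerate(skips):
    print(f"decoder layer {layer} skip input: {skip}")
```

Giving every decoder layer its own weights lets each resolution level draw on whichever encoder depths are most useful to it, instead of pairing decoder and encoder layers one-to-one as a standard UNet does.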