Source code for cyto_dl.nn.vits.seg

from typing import List, Optional, Union

import numpy as np
import torch
from einops.layers.torch import Rearrange
from monai.networks.blocks import UnetOutBlock, UnetResBlock, UpSample

from cyto_dl.nn.vits.mae import MAE_Encoder


class EncodedSkip(torch.nn.Module):
    def __init__(self, spatial_dims, num_patches, emb_dim, n_decoder_filters, layer):
        """
        layer = 0 is the smallest resolution, n is the highest.
        As the layer increases, the image size increases and the number of filters decreases.
        """
        super().__init__()
        upsample = 2**layer
        # spreading n_decoder_filters values per token over an upsample**spatial_dims voxel
        # block leaves this many channels per voxel
        self.n_out_channels = n_decoder_filters // (upsample**spatial_dims)
        if spatial_dims == 3:
            rearrange = Rearrange(
                " (n_patch_z n_patch_y n_patch_x) b (c uz uy ux) -> b c (n_patch_z uz) (n_patch_y uy) (n_patch_x ux)",
                n_patch_z=num_patches[0],
                n_patch_y=num_patches[1],
                n_patch_x=num_patches[2],
                c=self.n_out_channels,
                uz=upsample,
                uy=upsample,
                ux=upsample,
            )
        elif spatial_dims == 2:
            rearrange = Rearrange(
                " (n_patch_y n_patch_x) b (c uy ux) -> b c (n_patch_y uy) (n_patch_x ux)",
                n_patch_y=num_patches[0],
                n_patch_x=num_patches[1],
                c=self.n_out_channels,
                uy=upsample,
                ux=upsample,
            )
        # project tokens to the decoder filter budget, then fold them back into an image grid
        self.patch2img = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, n_decoder_filters),
            torch.nn.LayerNorm(n_decoder_filters),
            rearrange,
        )

    def forward(self, features):
        # features: (n_tokens, batch, emb_dim) -> (batch, n_out_channels, *spatial)
        return self.patch2img(features)
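
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal shape check for EncodedSkip, assuming arbitrary example sizes rather than
# library defaults: a 2 x 4 x 4 patch grid, a 192-dim embedding, 16 decoder filters, layer 1.
# Tokens arrive as (n_tokens, batch, emb_dim) and come out as an image-shaped feature map.
def _demo_encoded_skip():
    skip = EncodedSkip(
        spatial_dims=3, num_patches=[2, 4, 4], emb_dim=192, n_decoder_filters=16, layer=1
    )
    tokens = torch.rand(2 * 4 * 4, 1, 192)  # (n_tokens, batch, emb_dim)
    out = skip(tokens)
    # layer=1 doubles the patch grid and leaves 16 // 2**3 = 2 channels
    assert out.shape == (1, 2, 4, 8, 8)
    return out
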

class SuperresDecoder(torch.nn.Module):
    """Create a UNet-like decoder where each decoder layer is fed a skip connection consisting
    of a different weighted sum of intermediate layer features."""

    def __init__(
        self,
        spatial_dims: int = 3,
        num_patches: Optional[List[int]] = [2, 32, 32],
        base_patch_size: Optional[List[int]] = [4, 8, 8],
        emb_dim: Optional[int] = 192,
        n_decoder_filters: Optional[int] = 16,
        out_channels: Optional[int] = 6,
        upsample_factor: Optional[Union[int, List[int]]] = [2.6134, 2.5005, 2.5005],
        num_layer: Optional[int] = 3,
    ) -> None:
        super().__init__()
        # the effective depth is derived from the total per-patch upsampling factor;
        # the num_layer argument is not used directly
        total_upsample_factor = np.array(upsample_factor) * np.array(base_patch_size)
        self.num_layer = np.min(np.log2(total_upsample_factor)).astype(int)
        print(self.num_layer)
        # remaining non-power-of-two upsampling, applied by the final layer
        residual_resize_factor = list(total_upsample_factor / 2**self.num_layer)

        input_n_decoder_filters = n_decoder_filters
        self.upsampling = torch.nn.ModuleDict()
        for i in range(self.num_layer):
            skip = EncodedSkip(spatial_dims, num_patches, emb_dim, input_n_decoder_filters, i)
            n_input_channels = (
                n_decoder_filters + skip.n_out_channels if i > 0 else n_decoder_filters
            )
            self.upsampling[f"layer_{i}"] = torch.nn.ModuleDict(
                {
                    "skip": skip,
                    "upsample": torch.nn.Sequential(
                        *[
                            UnetResBlock(
                                spatial_dims=spatial_dims,
                                in_channels=n_input_channels,
                                out_channels=n_decoder_filters // 2,
                                stride=1,
                                kernel_size=3,
                                norm_name="INSTANCE",
                                dropout=0,
                            ),
                            # no convolution in upsample, do convolution at low resolution
                            UpSample(
                                spatial_dims=spatial_dims,
                                in_channels=n_decoder_filters // 2,
                                out_channels=n_decoder_filters // 2,
                                scale_factor=[2] * spatial_dims,
                                mode="nontrainable",
                            ),
                        ]
                    ),
                }
            )
            n_decoder_filters = n_decoder_filters // 2

        # final layer: apply the residual (fractional) upsampling and produce the output channels
        skip = EncodedSkip(
            spatial_dims, num_patches, emb_dim, input_n_decoder_filters, self.num_layer
        )
        n_input_channels = n_decoder_filters + skip.n_out_channels
        self.upsampling[f"layer_{i+1}"] = torch.nn.ModuleDict(
            {
                "skip": skip,
                "upsample": torch.nn.Sequential(
                    *[
                        UpSample(
                            spatial_dims=spatial_dims,
                            in_channels=n_input_channels,
                            out_channels=n_decoder_filters // 2,
                            scale_factor=residual_resize_factor,
                            mode="nontrainable",
                            interp_mode="trilinear",
                        ),
                        UnetResBlock(
                            spatial_dims=spatial_dims,
                            in_channels=n_decoder_filters // 2,
                            out_channels=n_decoder_filters // 2,
                            stride=1,
                            kernel_size=3,
                            norm_name="INSTANCE",
                            dropout=0,
                        ),
                        UnetOutBlock(
                            spatial_dims=spatial_dims,
                            in_channels=n_decoder_filters // 2,
                            out_channels=out_channels,
                            dropout=0,
                        ),
                    ]
                ),
            }
        )

    def forward(self, features: torch.Tensor):
        # features: (num_layer + 1, n_tokens + 1, batch, emb_dim)
        # remove global token
        features = features[:, 1:]
        prev = None
        for i in range(self.num_layer + 1):
            skip = self.upsampling[f"layer_{i}"]["skip"](features[i])
            if prev is not None:
                skip = torch.cat([skip, prev], dim=1)
            prev = self.upsampling[f"layer_{i}"]["upsample"](skip)
        return prev
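
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal end-to-end run of SuperresDecoder under assumed toy sizes (these are not the
# library defaults). With base_patch_size * upsample_factor = 2 per axis, the decoder infers
# num_layer = 1 and therefore consumes two skip connections, so the input must be stacked as
# (num_layer + 1, n_tokens + 1, batch, emb_dim), the extra leading token being the global
# token that forward() discards.
def _demo_superres_decoder():
    decoder = SuperresDecoder(
        spatial_dims=3,
        num_patches=[2, 4, 4],
        base_patch_size=[2, 2, 2],
        emb_dim=192,
        n_decoder_filters=16,
        out_channels=6,
        upsample_factor=1,  # output resolution equals the input image resolution
    )
    # 2 feature stacks, 2*4*4 patch tokens plus one global token, batch of 1
    features = torch.rand(2, 2 * 4 * 4 + 1, 1, 192)
    out = decoder(features)
    assert out.shape == (1, 6, 4, 8, 8)
    return out
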

class Seg_ViT(torch.nn.Module):
    """Class for training a simple convolutional decoder on top of a pretrained ViT backbone."""

    def __init__(
        self,
        spatial_dims: int = 3,
        num_patches: Optional[List[int]] = [2, 32, 32],
        base_patch_size: Optional[List[int]] = [16, 16, 16],
        emb_dim: Optional[int] = 768,
        decoder_layer: Optional[int] = 3,
        n_decoder_filters: Optional[int] = 16,
        out_channels: Optional[int] = 6,
        upsample_factor: Optional[List[int]] = [2.6134, 2.5005, 2.5005],
        encoder_ckpt: Optional[str] = None,
        freeze_encoder: Optional[bool] = True,
        **encoder_kwargs,
    ) -> None:
        """
        Parameters
        ----------
        spatial_dims: Optional[int] = 3
            Number of spatial dimensions
        num_patches: Optional[List[int]] = [2, 32, 32]
            Number of patches in each dimension (ZYX order)
        base_patch_size: Optional[List[int]] = [16, 16, 16]
            Base patch size in each dimension (ZYX order)
        emb_dim: Optional[int] = 768
            Embedding dimension of ViT backbone
        encoder_layer: Optional[int] = 12
            Number of layers in ViT backbone
        encoder_head: Optional[int] = 8
            Number of heads in ViT backbone
        decoder_layer: Optional[int] = 3
            Number of layers in convolutional decoder (the effective depth is recomputed from
            base_patch_size and upsample_factor inside SuperresDecoder)
        n_decoder_filters: Optional[int] = 16
            Number of filters in convolutional decoder
        out_channels: Optional[int] = 6
            Number of output channels in convolutional decoder. Should be 6 for instance segmentation.
        mask_ratio: Optional[float] = 0.75
            Ratio of patches to be masked out during training
        upsample_factor: Optional[List[int]] = [2.6134, 2.5005, 2.5005]
            Upsampling factor for each dimension (ZYX order). Default is AICS 20x to 100x objective upsampling
        encoder_ckpt: Optional[str] = None
            Path to pretrained ViT backbone checkpoint
        freeze_encoder: Optional[bool] = True
            Whether to freeze the encoder weights. The intermediate-feature weighting remains
            trainable so it can be finetuned.
        **encoder_kwargs
            Additional keyword arguments passed to MAE_Encoder
        """
        super().__init__()
        assert spatial_dims in (2, 3)
        if isinstance(num_patches, int):
            num_patches = [num_patches] * spatial_dims
        if isinstance(base_patch_size, int):
            base_patch_size = [base_patch_size] * spatial_dims
        if isinstance(upsample_factor, int):
            upsample_factor = [upsample_factor] * spatial_dims
        assert len(num_patches) == spatial_dims
        assert len(base_patch_size) == spatial_dims
        assert len(upsample_factor) == spatial_dims

        self.encoder = MAE_Encoder(
            spatial_dims=spatial_dims,
            num_patches=num_patches,
            base_patch_size=base_patch_size,
            emb_dim=emb_dim,
            **encoder_kwargs,
        )
        if encoder_ckpt is not None:
            model = torch.load(encoder_ckpt, map_location="cuda:0")
            # strip the Lightning "backbone.encoder." prefix and skip the
            # intermediate-feature-weighting keys before loading
            enc_state_dict = {
                k.replace("backbone.encoder.", ""): v
                for k, v in model["state_dict"].items()
                if "encoder" in k and "intermediate" not in k
            }
            self.encoder.load_state_dict(enc_state_dict, strict=False)
        if freeze_encoder:
            for name, param in self.encoder.named_parameters():
                # allow different weighting of internal activations for finetuning
                param.requires_grad = "intermediate_weighter" in name

        self.decoder = SuperresDecoder(
            spatial_dims=spatial_dims,
            num_patches=num_patches,
            base_patch_size=base_patch_size,
            emb_dim=emb_dim,
            num_layer=decoder_layer,
            n_decoder_filters=n_decoder_filters,
            out_channels=out_channels,
            upsample_factor=upsample_factor,
        )

    def forward(self, img):
        # run the (optionally frozen) encoder without masking, then decode to the output resolution
        features = self.encoder(img, mask_ratio=0)
        return self.decoder(features)
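
# --- Illustrative sketch (not part of the original module) ---
# Shows the key filtering/renaming Seg_ViT applies when loading a pretrained checkpoint:
# the Lightning "backbone.encoder." prefix is stripped and intermediate-weighting keys are
# skipped. The key names in the toy state dict below are hypothetical, chosen only to
# demonstrate the filter.
def _demo_checkpoint_key_remap():
    state_dict = {
        "backbone.encoder.patchify.weight": torch.zeros(1),  # kept, prefix stripped
        "backbone.encoder.intermediate_weighter.weight": torch.zeros(1),  # dropped
        "backbone.decoder.head.weight": torch.zeros(1),  # dropped (not an encoder key)
    }
    enc_state_dict = {
        k.replace("backbone.encoder.", ""): v
        for k, v in state_dict.items()
        if "encoder" in k and "intermediate" not in k
    }
    assert list(enc_state_dict) == ["patchify.weight"]
    return enc_state_dict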