from typing import List, Optional, Union

import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

from cyto_dl.nn.vits.blocks import CrossAttentionBlock
from cyto_dl.nn.vits.utils import (
    get_positional_embedding,
    match_tuple_dimensions,
    take_indexes,
)

class MAE_Decoder(torch.nn.Module):
    """Transformer decoder that turns masked-autoencoder features back into an image.

    The visible-token features are projected to the decoder width, padded with a
    learned mask token for every missing patch, put back into their original
    patch order, decoded, and reassembled into an image.
    """

    def __init__(
        self,
        num_patches: Union[int, List[int]],
        spatial_dims: int = 3,
        patch_size: Optional[Union[int, List[int]]] = 4,
        enc_dim: Optional[int] = 768,
        emb_dim: Optional[int] = 192,
        num_layer: Optional[int] = 4,
        num_head: Optional[int] = 3,
        has_cls_token: Optional[bool] = False,
        learnable_pos_embedding: Optional[bool] = True,
    ) -> None:
        """
        Parameters
        ----------
        num_patches: List[int], int
            Number of patches in each dimension. If int, the same number of patches is used for all dimensions.
        spatial_dims: int
            Number of spatial dimensions of the reconstructed image. Must be 2 or 3.
        patch_size: Tuple[int], int
            Size of each patch. If int, the same patch size is used for all dimensions.
        enc_dim: int
            Dimension of encoder
        emb_dim: int
            Dimension of decoder
        num_layer: int
            Number of transformer layers
        num_head: int
            Number of heads in transformer
        has_cls_token: bool
            Whether encoder features have a cls token
        learnable_pos_embedding: bool
            If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings.
            Empirically, fixed positional embeddings work better for brightfield images.

        Raises
        ------
        ValueError
            If ``spatial_dims`` is neither 2 nor 3.
        """
        super().__init__()
        num_patches, patch_size = match_tuple_dimensions(spatial_dims, [num_patches, patch_size])

        self.has_cls_token = has_cls_token

        self.projection_norm = nn.LayerNorm(emb_dim)
        self.projection = torch.nn.Linear(enc_dim, emb_dim)
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))

        self.pos_embedding = get_positional_embedding(
            num_patches, emb_dim, use_cls_token=has_cls_token, learnable=learnable_pos_embedding
        )

        self.transformer = torch.nn.Sequential(
            *[Block(emb_dim, num_head) for _ in range(num_layer)]
        )

        # The head emits one flattened patch per token. `patch2img` below splits that
        # vector into (c, *patch_size), so its width must be a multiple of
        # prod(patch_size).
        # NOTE(review): the original expression was lost; a single output channel
        # (width == prod(patch_size)) is assumed here - confirm.
        out_dim = int(torch.prod(torch.as_tensor(patch_size)))
        self.decoder_norm = nn.LayerNorm(emb_dim)
        self.head = torch.nn.Linear(emb_dim, out_dim)
        self.num_patches = torch.as_tensor(num_patches)

        if spatial_dims == 3:
            self.patch2img = Rearrange(
                "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
                n_patch_z=num_patches[0],
                n_patch_y=num_patches[1],
                n_patch_x=num_patches[2],
                patch_size_z=patch_size[0],
                patch_size_y=patch_size[1],
                patch_size_x=patch_size[2],
            )
        elif spatial_dims == 2:
            self.patch2img = Rearrange(
                "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
                n_patch_y=num_patches[0],
                n_patch_x=num_patches[1],
                patch_size_y=patch_size[0],
                patch_size_x=patch_size[1],
            )
        else:
            # Without this, `patch2img` would be missing and `forward` would fail
            # later with a much less helpful AttributeError.
            raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}")

        self.init_weight()

    def init_weight(self):
        """Initialise the mask token from a truncated normal distribution (std 0.02)."""
        trunc_normal_(self.mask_token, std=0.02)

    def adjust_indices_for_cls(self, indexes):
        """Shift patch indices so that a leading cls token stays at position 0.

        Parameters
        ----------
        indexes: torch.Tensor
            Integer index tensor with tokens on dim 0 and batch on dim 1.

        Returns
        -------
        torch.Tensor
            ``indexes`` unchanged when the decoder has no cls token. Otherwise a
            tensor with one extra leading row of zeros and every original index
            incremented by one.
        """
        if not self.has_cls_token:
            return indexes

        # Row of zeros keeps the class token as the first entry; the +1 makes room
        # for it in every other index.
        cls_row = torch.zeros(1, indexes.shape[1], device=indexes.device, dtype=torch.long)
        return torch.cat([cls_row, indexes + 1], dim=0)

    def add_mask_tokens(self, features, backward_indexes):
        """Append one mask token for every token that is missing from ``features``.

        Parameters
        ----------
        features: torch.Tensor
            Visible-token features shaped (visible tokens, batch, channels).
        backward_indexes: torch.Tensor
            Index tensor whose first dimension is the total number of tokens.

        Returns
        -------
        torch.Tensor
            ``features`` followed by the mask tokens along dim 0.
        """
        num_visible_tokens, batch_size, _ = features.shape
        # total number of tokens - number of visible tokens
        num_mask_tokens = backward_indexes.shape[0] - num_visible_tokens
        # broadcast the single learned (1, 1, c) token to (t, b, c)
        mask_tokens = self.mask_token.expand(num_mask_tokens, batch_size, -1)
        return torch.cat([features, mask_tokens], dim=0)

    def forward(self, features, forward_indexes, backward_indexes):
        """Decode encoder features into an image.

        Parameters
        ----------
        features: torch.Tensor
            Encoder features shaped (visible tokens, batch, enc_dim).
        forward_indexes: torch.Tensor
            Not used by this decoder; accepted so that all decoders share a signature.
        backward_indexes: torch.Tensor
            Indices that restore shuffled tokens to their original patch order
            (presumably produced by the encoder's masking step - verify against caller).

        Returns
        -------
        torch.Tensor
            Reconstruction laid out as (batch, channel, *spatial).
        """
        # project from encoder dimension to decoder dimension
        features = self.projection_norm(self.projection(features))

        backward_indexes = self.adjust_indices_for_cls(backward_indexes)

        # fill in deleted masked regions with mask token
        features = self.add_mask_tokens(features, backward_indexes)

        # unshuffle to original positions
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        # decode; the transformer blocks operate batch-first
        features = rearrange(features, "t b c -> b t c")
        features = self.transformer(features)
        features = rearrange(features, "b t c -> t b c")

        if self.has_cls_token:
            features = features[1:]  # remove global feature

        # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x)
        patches = self.head(self.decoder_norm(features))

        # patches to image
        return self.patch2img(patches)
class CrossMAE_Decoder(MAE_Decoder):
    """Decoder inspired by CrossMAE ("Rethinking Patch Dependence for Masked
    Autoencoders") where masked tokens only attend to visible tokens."""

    def __init__(
        self,
        num_patches: Union[int, List[int]],
        spatial_dims: int = 3,
        patch_size: Optional[Union[int, List[int]]] = 4,
        enc_dim: Optional[int] = 768,
        emb_dim: Optional[int] = 192,
        num_layer: Optional[int] = 4,
        num_head: Optional[int] = 3,
        has_cls_token: Optional[bool] = True,
        learnable_pos_embedding: Optional[bool] = True,
    ) -> None:
        """
        Parameters
        ----------
        num_patches: List[int], int
            Number of patches in each dimension. If int, the same number of patches is used for all dimensions.
        spatial_dims: int
            Number of spatial dimensions, forwarded to the base decoder.
        patch_size: Tuple[int]
            Size of each patch in each dimension. If int, the same patch size is used for all dimensions.
        enc_dim: int
            Dimension of encoder
        emb_dim: int
            Dimension of embedding
        num_layer: int
            Number of transformer layers
        num_head: int
            Number of heads in transformer
        has_cls_token: bool
            Whether encoder features have a cls token
        learnable_pos_embedding: bool
            If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used.
            Empirically, fixed positional embeddings work better for brightfield images.
        """
        super().__init__(
            num_patches,
            spatial_dims,
            patch_size,
            enc_dim,
            emb_dim,
            num_layer,
            num_head,
            has_cls_token,
            learnable_pos_embedding,
        )

        # Replace the self-attention stack built by the base class with
        # cross-attention blocks. ModuleList is the container meant for submodules;
        # children are still registered under the names "0", "1", ...
        self.transformer = nn.ModuleList(
            [
                CrossAttentionBlock(
                    encoder_dim=emb_dim,
                    decoder_dim=emb_dim,
                    num_heads=num_head,
                )
                for _ in range(num_layer)
            ]
        )

    def forward(self, features, forward_indexes, backward_indexes):
        """Reconstruct the masked patches by cross-attending to the visible tokens.

        Parameters
        ----------
        features: torch.Tensor
            Encoder features shaped (tokens, batch, channels), or with an extra
            leading dimension, in which case only the first entry is used.
        forward_indexes: torch.Tensor
            Indices that put tokens into shuffled order
            (presumably produced by the encoder's masking step - verify against caller).
        backward_indexes: torch.Tensor
            Indices that restore shuffled tokens to their original patch order.

        Returns
        -------
        torch.Tensor
            Image laid out as (batch, channel, *spatial). Positions of visible
            patches are filled with zeros; only masked patches are predicted.
        """
        # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 1 layer
        # features can be n t b c (if intermediate feature weighter used) or t b c if not
        if features.ndim == 4:
            features = features[0]
        T, B, _ = features.shape

        # we could do cross attention between decoder_dim queries and encoder_dim features,
        # but it seems to work fine having both at decoder_dim for now
        features = self.projection_norm(self.projection(features))

        backward_indexes = self.adjust_indices_for_cls(backward_indexes)
        forward_indexes = self.adjust_indices_for_cls(forward_indexes)

        features = self.add_mask_tokens(features, backward_indexes)

        # unshuffle to original positions for positional embedding so we can do cross attention during decoding
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        # reshuffle to shuffled positions for cross attention
        features = take_indexes(features, forward_indexes)
        # the first T tokens came from the encoder, everything after is a mask token
        visible, masked = features[:T], features[T:]

        masked = rearrange(masked, "t b c -> b t c")
        visible = rearrange(visible, "t b c -> b t c")

        for block in self.transformer:
            masked = block(masked, visible)

        # no need to remove cls token, it's a part of the features vector
        masked = rearrange(masked, "b t c -> t b c")

        # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x)
        masked = self.decoder_norm(masked)
        patches = self.head(masked)

        # add back in visible/encoded tokens that we don't calculate loss on
        num_visible_patches = T - int(self.has_cls_token)  # T counts the cls token when present
        visible_placeholder = torch.zeros(
            (num_visible_patches, B, patches.shape[-1]),
            requires_grad=False,
            device=patches.device,
            dtype=patches.dtype,
        )
        patches = torch.cat([visible_placeholder, patches], dim=0)

        if self.has_cls_token:
            # backward_indexes was shifted for the cls token above; drop that row and
            # undo the +1 so the indices address patch tokens only
            patches = take_indexes(patches, backward_indexes[1:] - 1)
        else:
            patches = take_indexes(patches, backward_indexes)

        # patches to image
        return self.patch2img(patches)