Source code for cyto_dl.nn.track_sequence_predictor

import math

import torch
from einops import rearrange
from monai.networks.nets import Regressor
from timm.models.vision_transformer import Block


class TrackClassifier(torch.nn.Module):
    """Transformer that encodes images in a sequence as tokens and classifies each of them."""

    def __init__(
        self,
        patch_size,
        pos_embedding_length=120,
        emb_dim=128,
        num_layer=8,
        num_head=4,
        num_classes=2,
    ) -> None:
        super().__init__()
        # fixed (non-trainable) sinusoidal positional embedding, interpolated in forward to the track length
        self.register_buffer("pos_embedding", positionalencoding1d(emb_dim, pos_embedding_length))
        self.pos_embedding.requires_grad = False

        # CNN encoder that maps each image in the track to an emb_dim-dimensional token
        self.image_encoder = Regressor(
            in_shape=patch_size, out_shape=emb_dim, channels=[8, 16, 32], strides=[2, 2, 2]
        )
        # stack of ViT-style transformer blocks followed by a per-token classification head
        self.transformer = torch.nn.Sequential(
            *[Block(emb_dim, num_head) for _ in range(num_layer)]
        )
        self.layer_norm = torch.nn.LayerNorm(emb_dim)
        self.classifier = torch.nn.Linear(emb_dim, num_classes)
    def forward(self, img):
        img = img.as_tensor()
        # move track length to front to embed images individually
        img = rearrange(img, "b track_len h w -> track_len b h w")
        # returns token x batch x embedding dim
        patches = self.image_encoder(img).unsqueeze(1)
        # interpolate positional embedding to match track length
        pe = torch.nn.functional.interpolate(
            self.pos_embedding, size=patches.shape[0], mode="linear"
        )
        pe = rearrange(pe, "emb_dim 1 tokens -> tokens 1 emb_dim")
        patches = patches + pe
        patches = self.layer_norm(self.transformer(patches))
        patches = self.classifier(patches)
        patches = rearrange(patches, "tokens batch emb_dim -> batch tokens emb_dim")
        return patches
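
The forward pass calls img.as_tensor(), so it expects a MONAI MetaTensor shaped batch x track_len x h x w; because each image is embedded individually, the original batch axis ends up as the channel axis of the per-image Regressor, so the model is naturally called with a batch of one track. A minimal usage sketch, assuming single-channel 64x64 patches and a 20-image track (the patch_size, track length, and random input below are illustrative, not values from the source):

import torch
from monai.data import MetaTensor

from cyto_dl.nn.track_sequence_predictor import TrackClassifier

# patch_size is the (channels, h, w) in_shape of the MONAI Regressor; 1 x 64 x 64 is an assumption
model = TrackClassifier(patch_size=(1, 64, 64), num_classes=2)

# one track of 20 random 64x64 frames, batch size 1; forward calls img.as_tensor(),
# so the input is wrapped in a MetaTensor rather than passed as a plain torch.Tensor
track = MetaTensor(torch.randn(1, 20, 64, 64))

with torch.no_grad():
    logits = model(track)

print(logits.shape)  # torch.Size([1, 20, 2]) -> per-image class logits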
def positionalencoding1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length*d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with odd dim (got dim={:d})".format(d_model)
        )
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model)
    )
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    return rearrange(pe, "tokens emb_dim -> emb_dim 1 tokens")
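
positionalencoding1d returns the standard fixed sinusoidal encoding, rearranged to an (emb_dim, 1, tokens) layout so that torch.nn.functional.interpolate can resample it along the token axis, as done in TrackClassifier.forward. A short shape check, with emb_dim=128, length=120, and a target track length of 20 chosen only for illustration:

import torch
from einops import rearrange

from cyto_dl.nn.track_sequence_predictor import positionalencoding1d

pe = positionalencoding1d(128, 120)
print(pe.shape)  # torch.Size([128, 1, 120]) -> (emb_dim, 1, tokens)

# as in TrackClassifier.forward: resample the 120 stored positions to a 20-image track
pe_track = torch.nn.functional.interpolate(pe, size=20, mode="linear")
pe_track = rearrange(pe_track, "emb_dim 1 tokens -> tokens 1 emb_dim")
print(pe_track.shape)  # torch.Size([20, 1, 128]), ready to add to the (tokens, 1, emb_dim) patch tokens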