Source code for cyto_dl.nn.vits.blocks.cross_attention

import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath
from timm.models.vision_transformer import Block

# from https://github.com/TonyLianLong/CrossMAE/blob/main/transformer_utils.py


class Mlp(nn.Module):
    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
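
# Illustrative usage sketch (added for documentation; not part of the original
# module). The Mlp keeps the token shape: hidden_features only widens the
# intermediate projection. The helper name and all shapes are hypothetical.
def _example_mlp():
    import torch

    mlp = Mlp(in_features=192, hidden_features=768, drop=0.1)
    tokens = torch.randn(2, 16, 192)  # (batch, num_tokens, in_features)
    out = mlp(tokens)  # fc1 -> GELU -> dropout -> fc2 -> dropout
    assert out.shape == tokens.shape  # output keeps (2, 16, 192)
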
class CrossAttention(nn.Module):
    def __init__(
        self,
        encoder_dim,
        decoder_dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = decoder_dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias)
        self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(decoder_dim, decoder_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y):
        """Query from decoder (x), key and value from encoder (y)."""
        B, N, C = x.shape
        Ny = y.shape[1]
        # project decoder tokens to per-head queries, encoder tokens to keys/values
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = (
            self.kv(y)
            .reshape(B, Ny, 2, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]
        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            # only apply attention dropout during training
            dropout_p=self.attn_drop if self.training else 0.0,
        )
        x = attn.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
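
# Illustrative usage sketch (added for documentation; not part of the original
# module). CrossAttention attends decoder tokens (queries) over encoder tokens
# (keys/values) and returns one decoder_dim output per decoder token. The
# helper name and all shapes are hypothetical.
def _example_cross_attention():
    import torch

    attn = CrossAttention(encoder_dim=768, decoder_dim=192, num_heads=8)
    decoder_tokens = torch.randn(2, 49, 192)   # (B, N, decoder_dim) -> queries
    encoder_tokens = torch.randn(2, 196, 768)  # (B, Ny, encoder_dim) -> keys/values
    out = attn(decoder_tokens, encoder_tokens)
    assert out.shape == decoder_tokens.shape   # (2, 49, 192)
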
class CrossAttentionBlock(nn.Module):
    def __init__(
        self,
        encoder_dim,
        decoder_dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(decoder_dim)
        self.cross_attn = CrossAttention(
            encoder_dim,
            decoder_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth; we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(decoder_dim)
        mlp_hidden_dim = int(decoder_dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

    def forward(self, x, y):
        """x: decoder feature; y: encoder feature (after layernorm)."""
        x = x + self.drop_path(self.cross_attn(self.norm1(x), y))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
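
# Illustrative usage sketch (added for documentation; not part of the original
# module). The block is a pre-norm residual unit: cross-attention into the
# encoder tokens, then an MLP, each wrapped as x + drop_path(...). The helper
# name and all shapes are hypothetical.
def _example_cross_attention_block():
    import torch

    block = CrossAttentionBlock(encoder_dim=768, decoder_dim=192, num_heads=8, drop_path=0.1)
    decoder_tokens = torch.randn(2, 49, 192)
    encoder_tokens = torch.randn(2, 196, 768)
    out = block(decoder_tokens, encoder_tokens)
    assert out.shape == decoder_tokens.shape  # residual connections keep (2, 49, 192)
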
class CrossSelfBlock(nn.Module):
    def __init__(
        self,
        emb_dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.x_attn_block = CrossAttentionBlock(
            emb_dim,
            emb_dim,
            num_heads,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            drop,
            attn_drop,
            drop_path,
            act_layer,
            norm_layer,
        )
        self.self_attn_block = Block(dim=emb_dim, num_heads=num_heads)

    def forward(self, x, y):
        """x: decoder feature; y: encoder feature."""
        x = self.x_attn_block(x, y)
        x = self.self_attn_block(x)
        return x
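
# Illustrative usage sketch (added for documentation; not part of the original
# module). CrossSelfBlock uses emb_dim for both streams, so encoder and decoder
# tokens must share the same embedding dimension: cross-attention into y is
# followed by a standard timm self-attention Block over x. The helper name and
# all shapes are hypothetical.
def _example_cross_self_block():
    import torch

    block = CrossSelfBlock(emb_dim=192, num_heads=8)
    decoder_tokens = torch.randn(2, 49, 192)
    encoder_tokens = torch.randn(2, 196, 192)  # same embedding dim as the decoder
    out = block(decoder_tokens, encoder_tokens)
    assert out.shape == decoder_tokens.shape  # (2, 49, 192)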