import torch.nn as nn
import torch.nn.functional as F

from timm.layers import DropPath
from timm.models.vision_transformer import Block

# Adapted from https://github.com/TonyLianLong/CrossMAE/blob/main/transformer_utils.py


class Mlp(nn.Module):
    """Standard transformer MLP: two linear layers with activation and dropout."""

    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class CrossAttention(nn.Module):
    """Multi-head attention where queries come from one stream (decoder) and
    keys/values from another (encoder)."""

    def __init__(
        self,
        encoder_dim,
        decoder_dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = decoder_dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias)
        self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias)
        # Kept as a float: F.scaled_dot_product_attention takes dropout_p directly.
        self.attn_drop = attn_drop
        self.proj = nn.Linear(decoder_dim, decoder_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y):
        """Query from decoder tokens (x); key and value from encoder tokens (y)."""
        B, N, C = x.shape  # C == decoder_dim
        Ny = y.shape[1]
        # (B, N, C) -> (B, num_heads, N, head_dim)
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        # (B, Ny, 2*C) -> (2, B, num_heads, Ny, head_dim)
        kv = (
            self.kv(y)
            .reshape(B, Ny, 2, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]
        attn = (
            F.scaled_dot_product_attention(
                q,
                k,
                v,
                # Apply attention dropout only during training; this functional
                # API would otherwise apply dropout_p at eval time too.
                dropout_p=self.attn_drop if self.training else 0.0,
                scale=self.scale,  # honors qk_scale; requires PyTorch >= 2.1
            )
            .transpose(1, 2)
            .reshape(B, N, C)
        )
        x = self.proj(attn)
        x = self.proj_drop(x)
        return x

class CrossAttentionBlock(nn.Module):
    """Transformer decoder block: cross-attention followed by an MLP, each with
    a pre-norm and a residual connection."""

    def __init__(
        self,
        encoder_dim,
        decoder_dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(decoder_dim)
        self.cross_attn = CrossAttention(
            encoder_dim,
            decoder_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth; it remains to be seen whether
        # this works better than dropout here.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(decoder_dim)
        mlp_hidden_dim = int(decoder_dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

    def forward(self, x, y):
        """
        x: decoder feature; y: encoder feature (after layernorm)
        """
        x = x + self.drop_path(self.cross_attn(self.norm1(x), y))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class CrossSelfBlock(nn.Module):
    """A cross-attention block followed by a standard timm self-attention block.
    Both streams must share a single embedding dimension."""

    def __init__(
        self,
        emb_dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.x_attn_block = CrossAttentionBlock(
            emb_dim,
            emb_dim,
            num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        # NOTE: only dim and num_heads are passed through here; timm's Block
        # keeps its own defaults for mlp_ratio, qkv_bias, dropout, and drop path.
        self.self_attn_block = Block(dim=emb_dim, num_heads=num_heads)

    def forward(self, x, y):
        """
        x: decoder feature; y: encoder feature
        """
        x = self.x_attn_block(x, y)
        x = self.self_attn_block(x)
        return x
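

if __name__ == "__main__":
    # Minimal usage sketch, not part of the upstream module. The shapes below
    # are illustrative assumptions only: a ViT-B-like encoder (dim 768) with
    # 196 visible-patch tokens feeding a 512-dim decoder with 49 mask queries.
    import torch

    B, N_enc, N_dec = 2, 196, 49
    encoder_dim, decoder_dim, num_heads = 768, 512, 16

    y = torch.randn(B, N_enc, encoder_dim)  # encoder features (keys/values)
    x = torch.randn(B, N_dec, decoder_dim)  # decoder queries

    block = CrossAttentionBlock(encoder_dim, decoder_dim, num_heads)
    out = block(x, y)
    assert out.shape == (B, N_dec, decoder_dim)

    # CrossSelfBlock requires encoder and decoder to share one embedding dim.
    y_dec = torch.randn(B, N_enc, decoder_dim)
    block2 = CrossSelfBlock(decoder_dim, num_heads)
    assert block2(x, y_dec).shape == (B, N_dec, decoder_dim)

    print("forward passes OK:", out.shape)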