Source code for cyto_dl.nn.vits.blocks.masked_unit_attention

# inspired by https://github.com/facebookresearch/hiera/tree/main
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Reduce
from timm.models.layers import DropPath, Mlp

from cyto_dl.nn.vits.utils import match_tuple_dimensions


class MaskUnitAttention(torch.nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        spatial_dims: int = 3,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        q_stride=[1, 1, 1],
        patches_per_mask_unit=[2, 12, 12],
    ):
        """
        Parameters
        ----------
        dim : int
            Dimension of the input features.
        dim_out : int
            Dimension of the output features.
        spatial_dims : int, optional
            Number of spatial dimensions, by default 3.
        num_heads : int, optional
            Number of attention heads, by default 8.
        qkv_bias : bool, optional
            If True, add a learnable bias to query, key, value, by default False.
        attn_drop : float, optional
            Dropout rate for attention, by default 0.0.
        proj_drop : float, optional
            Dropout rate for projection, by default 0.0.
        q_stride : List[int], optional
            Stride for query, by default [1, 1, 1].
        patches_per_mask_unit : List[int], optional
            Number of patches per mask unit, by default [2, 12, 12].
        """
        super().__init__()
        q_stride, patches_per_mask_unit = match_tuple_dimensions(
            spatial_dims, [q_stride, patches_per_mask_unit]
        )

        self.spatial_dims = spatial_dims
        self.num_heads = num_heads
        self.head_dim = dim_out // num_heads
        self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)
        self.dim_out = dim_out
        self.q_stride = np.array(q_stride)
        self.pooled_patches_per_mask_unit = (
            np.array(patches_per_mask_unit) / self.q_stride
        ).astype(int)
    def forward(self, x):
        # project and split into q, k, v embeddings
        qkv = rearrange(
            self.qkv(x),
            "batch num_mask_units tokens_per_mask_unit (head_dim num_heads qkv) -> qkv batch num_mask_units num_heads tokens_per_mask_unit head_dim",
            head_dim=self.head_dim,
            qkv=3,
            num_heads=self.num_heads,
        )

        q, k, v = qkv[0], qkv[1], qkv[2]

        if np.any(self.q_stride > 1):
            # within a mask unit, tokens are spatially ordered; max pool the
            # queries over each q_stride-sized spatial window of tokens
            if self.spatial_dims == 3:
                q = reduce(
                    q,
                    "batch num_mask_units num_heads (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> batch num_mask_units num_heads (n_patches_z n_patches_y n_patches_x) c",
                    reduction="max",
                    q_stride_z=self.q_stride[0],
                    q_stride_y=self.q_stride[1],
                    q_stride_x=self.q_stride[2],
                    n_patches_z=self.pooled_patches_per_mask_unit[0],
                    n_patches_y=self.pooled_patches_per_mask_unit[1],
                    n_patches_x=self.pooled_patches_per_mask_unit[2],
                )
            elif self.spatial_dims == 2:
                q = reduce(
                    q,
                    "batch num_mask_units num_heads (n_patches_y q_stride_y n_patches_x q_stride_x) c -> batch num_mask_units num_heads (n_patches_y n_patches_x) c",
                    reduction="max",
                    q_stride_y=self.q_stride[0],
                    q_stride_x=self.q_stride[1],
                    n_patches_y=self.pooled_patches_per_mask_unit[0],
                    n_patches_x=self.pooled_patches_per_mask_unit[1],
                )

        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.attn_drop,
        )

        # combine heads into a single channel dimension
        x = rearrange(
            attn, "b mask_units n_heads t c -> b mask_units t (n_heads c)", n_heads=self.num_heads
        )
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
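
# Minimal usage sketch (illustrative only, not part of the original module). The
# helper name and all shapes/hyperparameters below are assumptions chosen to show
# the expected input layout of batch x mask units x tokens per mask unit x emb dim,
# and how query pooling shrinks the token count within each mask unit.
def _mask_unit_attention_example():
    attn = MaskUnitAttention(
        dim=96,
        dim_out=192,
        spatial_dims=3,
        num_heads=8,
        q_stride=[1, 2, 2],
        patches_per_mask_unit=[2, 12, 12],
    )
    # 1 image, 4 mask units, each containing 2*12*12 spatially ordered patch tokens
    x = torch.randn(1, 4, 2 * 12 * 12, 96)
    out = attn(x)
    # query pooling with stride [1, 2, 2] reduces 288 tokens to 2*6*6 = 72
    assert out.shape == (1, 4, 72, 192)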
class HieraBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        heads: int,
        spatial_dims: int = 3,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        act_layer: nn.Module = nn.GELU,
        q_stride: List[int] = [1, 1, 1],
        patches_per_mask_unit: List[int] = [2, 12, 12],
    ):
        """
        Parameters
        ----------
        dim : int
            Dimension of the input features.
        dim_out : int
            Dimension of the output features.
        heads : int
            Number of attention heads.
        spatial_dims : int, optional
            Number of spatial dimensions, by default 3.
        mlp_ratio : float, optional
            Ratio of MLP hidden dim to embedding dim, by default 4.0.
        drop_path : float, optional
            Drop path (stochastic depth) rate, by default 0.0.
        norm_layer : nn.Module, optional
            Normalization layer, by default nn.LayerNorm.
        act_layer : nn.Module, optional
            Activation layer for the MLP, by default nn.GELU.
        q_stride : List[int], optional
            Stride for query, by default [1, 1, 1].
        patches_per_mask_unit : List[int], optional
            Number of patches per mask unit, by default [2, 12, 12].
        """
        super().__init__()
        patches_per_mask_unit, q_stride = match_tuple_dimensions(
            spatial_dims, [patches_per_mask_unit, q_stride]
        )

        self.spatial_dims = spatial_dims
        self.dim = dim
        self.dim_out = dim_out
        self.q_stride = q_stride

        self.norm1 = norm_layer(dim)

        do_pool = np.any(np.array(q_stride) > 1) or dim != dim_out

        self.attn = MaskUnitAttention(
            dim,
            dim_out,
            spatial_dims=spatial_dims,
            num_heads=heads,
            q_stride=q_stride,
            patches_per_mask_unit=patches_per_mask_unit,
        )

        self.norm2 = norm_layer(dim_out)
        self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)

        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        # mean pooling by q_stride within a mask unit
        if self.spatial_dims == 3:
            skip_connection_pooling = Reduce(
                "b n (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_z n_patches_y n_patches_x) c",
                reduction="mean",
                q_stride_z=self.q_stride[0],
                q_stride_y=self.q_stride[1],
                q_stride_x=self.q_stride[2],
                n_patches_z=self.attn.pooled_patches_per_mask_unit[0],
                n_patches_y=self.attn.pooled_patches_per_mask_unit[1],
                n_patches_x=self.attn.pooled_patches_per_mask_unit[2],
            )
        elif self.spatial_dims == 2:
            skip_connection_pooling = Reduce(
                "b n (n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_y n_patches_x) c",
                reduction="mean",
                q_stride_y=self.q_stride[0],
                q_stride_x=self.q_stride[1],
                n_patches_y=self.attn.pooled_patches_per_mask_unit[0],
                n_patches_x=self.attn.pooled_patches_per_mask_unit[1],
            )

        self.proj = (
            torch.nn.Sequential(skip_connection_pooling, nn.Linear(dim, dim_out))
            if do_pool
            else torch.nn.Identity()
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: batch x mask units x tokens x emb_dim
        """
        # attention + q pooling
        x_norm = self.norm1(x)

        # change dimension and subsample within mask unit for skip connection
        x = self.proj(x_norm)
        x = x + self.drop_path(self.attn(x_norm))

        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
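
# Minimal usage sketch (illustrative only, not part of the original module). The
# helper name and all sizes are assumptions; it shows a single HieraBlock that
# doubles the channel dimension and pools queries by [1, 2, 2] within each mask
# unit, so the attention branch and the pooled skip connection agree in shape.
def _hiera_block_example():
    block = HieraBlock(
        dim=96,
        dim_out=192,
        heads=8,
        spatial_dims=3,
        q_stride=[1, 2, 2],
        patches_per_mask_unit=[2, 12, 12],
    )
    x = torch.randn(1, 4, 2 * 12 * 12, 96)  # batch x mask units x tokens x emb_dim
    y = block(x)
    # both branches end up with 2*6*6 = 72 tokens per mask unit and 192 channels
    assert y.shape == (1, 4, 72, 192)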