# modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124
# inspired by https://github.com/facebookresearch/hiera
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block
from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify
from cyto_dl.nn.vits.blocks.masked_unit_attention import HieraBlock
from cyto_dl.nn.vits.blocks.patchify import PatchifyHiera
from cyto_dl.nn.vits.utils import match_tuple_dimensions
class MAE_Encoder(torch.nn.Module):
def __init__(
self,
        num_patches: Union[int, List[int]],
spatial_dims: int = 3,
patch_size: Union[int, List[int]] = 4,
emb_dim: Optional[int] = 192,
num_layer: Optional[int] = 12,
num_head: Optional[int] = 3,
context_pixels: Optional[Union[int, List[int]]] = 0,
input_channels: Optional[int] = 1,
n_intermediate_weights: Optional[int] = -1,
) -> None:
"""
Parameters
----------
num_patches: List[int], int
Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same.
spatial_dims: int
Number of spatial dimensions
        patch_size: int, List[int]
            Size of each patch. If a single int is provided, the same patch size is used in each dimension.
emb_dim: int
Dimension of embedding
num_layer: int
Number of transformer layers
num_head: int
Number of heads in transformer
context_pixels: List[int], int
Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same.
input_channels: int
Number of input channels
        n_intermediate_weights: int
            Number of weighted sums of intermediate layer outputs to compute. If <= 0, intermediate outputs are not weighted and only the final layer output is returned.
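
        Example
        -------
        Illustrative sketch only; the shapes and values below are assumptions, not defaults of this repository::

            encoder = MAE_Encoder(num_patches=[4, 16, 16], spatial_dims=3, patch_size=[2, 8, 8])
            # img: (batch, channels, 8, 128, 128), i.e. num_patches * patch_size per spatial axis
            features, mask, forward_indexes, backward_indexes = encoder(img, mask_ratio=0.75)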
"""
super().__init__()
num_patches, patch_size, context_pixels = match_tuple_dimensions(
spatial_dims, [num_patches, patch_size, context_pixels]
)
self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.patchify = Patchify(
patch_size, emb_dim, num_patches, spatial_dims, context_pixels, input_channels
)
weight_intermediates = n_intermediate_weights > 0
if weight_intermediates:
self.transformer = torch.nn.ModuleList(
[Block(emb_dim, num_head) for _ in range(num_layer)]
)
else:
self.transformer = torch.nn.Sequential(
*[Block(emb_dim, num_head) for _ in range(num_layer)]
)
self.layer_norm = torch.nn.LayerNorm(emb_dim)
self.intermediate_weighter = (
IntermediateWeigher(num_layer, emb_dim, n_intermediate_weights)
if weight_intermediates
else None
)
self.init_weight()
    def init_weight(self):
trunc_normal_(self.cls_token, std=0.02)
    def forward(self, img, mask_ratio=0.75):
patches, mask, forward_indexes, backward_indexes = self.patchify(img, mask_ratio)
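        # patchify returns token-first patches (n_tokens, batch, emb_dim) along with the mask and the
        # forward/backward indexes used to shuffle and later restore the masked tokens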
patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
patches = rearrange(patches, "t b c -> b t c")
if self.intermediate_weighter is not None:
intermediates = [patches]
for block in self.transformer:
patches = block(patches)
intermediates.append(patches)
features = self.layer_norm(self.intermediate_weighter(intermediates))
features = rearrange(features, "n b t c -> n t b c")
else:
features = self.layer_norm(self.transformer(patches))
features = rearrange(features, "b t c -> t b c")
if mask_ratio > 0:
return features, mask, forward_indexes, backward_indexes
return features
class JEPAEncoder(torch.nn.Module):
def __init__(
self,
num_patches: Union[int, List[int]],
spatial_dims: int = 3,
patch_size: Union[int, List[int]] = 4,
emb_dim: Optional[int] = 192,
num_layer: Optional[int] = 12,
num_head: Optional[int] = 3,
context_pixels: Optional[Union[int, List[int]]] = 0,
input_channels: Optional[int] = 1,
learnable_pos_embedding: Optional[bool] = True,
) -> None:
"""
Parameters
----------
num_patches: List[int], int
Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same.
spatial_dims: int
Number of spatial dimensions
        patch_size: int, List[int]
            Size of each patch. If a single int is provided, the same patch size is used in each dimension.
emb_dim: int
Dimension of embedding
num_layer: int
Number of transformer layers
num_head: int
Number of heads in transformer
context_pixels: List[int], int
Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same.
input_channels: int
Number of input channels
learnable_pos_embedding: bool
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images.
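
        Example
        -------
        Illustrative sketch only; the shapes and values below are assumptions, not defaults of this repository::

            encoder = JEPAEncoder(num_patches=[4, 16, 16], spatial_dims=3, patch_size=[2, 8, 8])
            # img: (batch, channels, 8, 128, 128), i.e. num_patches * patch_size per spatial axis
            features = encoder(img)  # (batch, n_tokens, emb_dim)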
"""
super().__init__()
num_patches, patch_size, context_pixels = match_tuple_dimensions(
spatial_dims, [num_patches, patch_size, context_pixels]
)
self.patchify = Patchify(
patch_size=patch_size,
emb_dim=emb_dim,
n_patches=num_patches,
spatial_dims=spatial_dims,
context_pixels=context_pixels,
input_channels=input_channels,
learnable_pos_embedding=learnable_pos_embedding,
)
self.transformer = torch.nn.Sequential(
*[Block(emb_dim, num_head) for _ in range(num_layer)]
)
self.layer_norm = torch.nn.LayerNorm(emb_dim)
    def forward(self, img, patchify=True):
if patchify:
patches, _, _, _ = self.patchify(img, mask_ratio=0)
else:
patches = img
patches = rearrange(patches, "t b c -> b t c")
features = self.layer_norm(self.transformer(patches))
return features
class SpatialMerger(nn.Module):
"""Class for converting multi-resolution Hiera features to the same (lowest) spatial resolution
via convolution."""
def __init__(
self, downsample_factor: List[int], in_dim: int, out_dim: int, spatial_dims: int = 3
):
super().__init__()
downsample_factor = match_tuple_dimensions(spatial_dims, [downsample_factor])[0]
self.spatial_dims = spatial_dims
conv_fn = nn.Conv3d if spatial_dims == 3 else nn.Conv2d
conv = conv_fn(
in_channels=in_dim,
out_channels=out_dim,
kernel_size=downsample_factor,
stride=downsample_factor,
padding=0,
bias=False,
)
if spatial_dims == 3:
tokens2img = Rearrange(
"b n_mu (z y x) c -> (b n_mu) c z y x",
z=downsample_factor[0],
y=downsample_factor[1],
x=downsample_factor[2],
)
else:
tokens2img = Rearrange(
"b n_mu (y x) c -> (b n_mu) c y x",
y=downsample_factor[0],
x=downsample_factor[1],
)
self.model = nn.Sequential(tokens2img, conv)
    def forward(self, x):
b, n_mu, _, _ = x.shape
x = self.model(x)
if self.spatial_dims == 3:
x = rearrange(x, "(b n_mu) c z y x -> b n_mu (z y x) c", b=b, n_mu=n_mu)
else:
x = rearrange(x, "(b n_mu) c y x -> b n_mu (y x) c", b=b, n_mu=n_mu)
return x
class HieraEncoder(torch.nn.Module):
def __init__(
self,
num_patches: Union[int, List[int]],
num_mask_units: Union[int, List[int]],
architecture: List[Dict],
emb_dim: int = 64,
spatial_dims: int = 3,
patch_size: Union[int, List[int]] = 4,
context_pixels: Optional[Union[int, List[int]]] = 0,
input_channels: Optional[int] = 1,
save_layers: Optional[bool] = False,
) -> None:
"""
Parameters
----------
num_patches: int, List[int]
Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same.
num_mask_units: int, List[int]
Number of mask units in each dimension. If a single int is provided, the number of mask units in each dimension will be the same.
architecture: List[Dict]
List of dictionaries specifying the architecture of the transformer. Each dictionary should have the following keys:
- repeat: int
Number of times to repeat the block
- num_heads: int
Number of heads in the multihead attention
- q_stride: int, List[int]
Stride for the query in each spatial dimension
- self_attention: bool
Whether to use self attention or mask unit attention
            On the last repeat of each non-self-attention stage, the embedding dimension is doubled and spatial pooling with `q_stride` is performed within each mask unit. For example, for a stage with embed_dim=4, q_stride=2, and repeat=2, the first block just does mask unit attention, while the second produces an 8-dimensional output that has been spatially pooled.
emb_dim: int
Dimension of embedding
spatial_dims: int
Number of spatial dimensions
        patch_size: int, List[int]
            Size of each patch. If a single int is provided, the same patch size is used in each dimension.
        context_pixels: int, List[int]
            Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same.
input_channels: int
Number of input channels
save_layers: bool
Whether to save the intermediate layer outputs
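
        Example
        -------
        Illustrative sketch of an ``architecture`` specification; the values are assumptions, not defaults of this repository. Per axis, the product of the ``q_stride`` values must divide the number of patches per mask unit (``num_patches / num_mask_units``)::

            architecture = [
                # mask unit attention stages: the last repeat of each doubles the embedding
                # dimension and pools spatially by q_stride within each mask unit
                {"repeat": 2, "num_heads": 1, "q_stride": [1, 2, 2]},
                {"repeat": 2, "num_heads": 2, "q_stride": [1, 2, 2]},
                # final stage runs full self attention over the pooled mask unit tokens
                {"repeat": 2, "num_heads": 4, "self_attention": True},
            ]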
"""
super().__init__()
num_patches, num_mask_units, patch_size, context_pixels = match_tuple_dimensions(
spatial_dims, [num_patches, num_mask_units, patch_size, context_pixels]
)
# make sure q stride shape matches spatial dims
for i in range(len(architecture)):
if "q_stride" in architecture[i]:
architecture[i]["q_stride"] = match_tuple_dimensions(
spatial_dims, [architecture[i]["q_stride"]]
)[0]
self.save_layers = save_layers
self.patchify = PatchifyHiera(
patch_size=patch_size,
n_patches=num_patches,
emb_dim=emb_dim,
spatial_dims=spatial_dims,
context_pixels=context_pixels,
input_channels=input_channels,
mask_units_per_dim=num_mask_units,
)
patches_per_mask_unit = np.array(num_patches) // np.array(num_mask_units)
total_downsampling_per_axis = np.prod(
[block.get("q_stride", [1] * spatial_dims) for block in architecture], axis=0
)
assert np.all(
patches_per_mask_unit - total_downsampling_per_axis >= 0
), f"Number of mask units must be greater than the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio. Please adjust your q_stride or increase the number of patches per mask unit."
assert np.all(
patches_per_mask_unit % total_downsampling_per_axis == 0
), f"Number of mask units must be divisible by the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio. Please adjust your q_stride"
        # number of output features doubles on the last block of each masked unit attention stage, stays constant during self attention blocks
self.final_dim = emb_dim * (2 ** len(architecture))
self.save_block_idxs = []
self.save_block_dims = []
self.spatial_mergers = torch.nn.ParameterDict({})
transformer = []
num_blocks = 0
for stage_num, stage in enumerate(architecture):
# use mask unit attention until first layer that uses self attention
if stage.get("self_attention", False):
break
print(f"Stage: {stage_num}")
for block in range(stage["repeat"]):
is_last = block == stage["repeat"] - 1
# do spatial pooling within mask unit on last block of stage
q_stride = stage["q_stride"] if is_last else [1] * spatial_dims
# double embedding dimension in last block of stage
dim_in = emb_dim * (2**stage_num)
dim_out = dim_in if not is_last else dim_in * 2
print(
f"\tBlock {block}:\t\tdim_in: {dim_in}, dim_out: {dim_out}, num_heads: {stage['num_heads']}, q_stride: {q_stride}, patches_per_mask_unit: {patches_per_mask_unit}"
)
transformer.append(
HieraBlock(
dim=dim_in,
dim_out=dim_out,
heads=stage["num_heads"],
spatial_dims=spatial_dims,
q_stride=q_stride,
patches_per_mask_unit=patches_per_mask_unit,
)
)
if is_last:
# save the block before the spatial pooling unless it's the final stage
save_block = (
num_blocks - 1 if stage_num < len(architecture) - 1 else num_blocks
)
self.save_block_idxs.append(save_block)
self.save_block_dims.append(dim_in)
# create a spatial merger for combining tokens pre-downsampling, last stage doesn't need merging since it has expected num channels, spatial shape
self.spatial_mergers[f"block_{save_block}"] = (
SpatialMerger(
patches_per_mask_unit,
dim_in,
self.final_dim,
spatial_dims=spatial_dims,
)
if stage_num < len(architecture) - 1
else torch.nn.Identity()
)
# at end of each layer, patches per mask unit is reduced as we pool spatially within mask units
patches_per_mask_unit = patches_per_mask_unit // np.array(stage["q_stride"])
num_blocks += 1
self.mask_unit_transformer = torch.nn.Sequential(*transformer)
self.save_block_dims.append(self.final_dim)
self.save_block_dims.reverse()
self.self_attention_transformer = torch.nn.Sequential(
*[Block(self.final_dim, stage["num_heads"]) for _ in range(stage["repeat"])]
)
self.layer_norm = torch.nn.LayerNorm(self.final_dim)
    def forward(self, img, mask_ratio):
patches, mask, forward_indexes, backward_indexes = self.patchify(img, mask_ratio)
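        # patches are grouped by mask unit: (batch, n_mask_units, tokens_per_mask_unit, emb_dim);
        # saved stage outputs are merged to the final spatial resolution and summed below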
# mask unit attention
mask_unit_embeddings = 0.0
save_layers = []
for i, block in enumerate(self.mask_unit_transformer):
patches = block(patches)
if i in self.save_block_idxs:
mask_unit_embeddings += self.spatial_mergers[f"block_{i}"](patches)
if self.save_layers:
save_layers.append(patches)
# combine mask units and tokens for full self attention transformer
mask_unit_embeddings = rearrange(mask_unit_embeddings, "b n_mu t c -> b (n_mu t) c")
mask_unit_embeddings = self.self_attention_transformer(mask_unit_embeddings)
mask_unit_embeddings = self.layer_norm(mask_unit_embeddings)
mask_unit_embeddings = rearrange(mask_unit_embeddings, "b t c -> t b c")
return mask_unit_embeddings, mask, forward_indexes, backward_indexes # , save_layers