Source code for cyto_dl.nn.vits.utils

from typing import Sequence

import numpy as np
import torch
from einops import rearrange, repeat
from monai.utils.misc import ensure_tuple_rep
from positional_encodings.torch_encodings import (
    PositionalEncoding2D,
    PositionalEncoding3D,
)
from timm.models.layers import trunc_normal_


def take_indexes(sequences, indexes):
    """Gather tokens from a (token, batch, channel) sequence along the token dimension,
    using a separate index order for each batch element."""
    return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))
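
# Illustrative usage (a minimal sketch, not part of the original module): reorder the tokens
# of each batch element independently. The tensor names below are hypothetical.
#
#   tokens = torch.randn(4, 2, 8)  # (tokens, batch, channels)
#   indexes = torch.stack([torch.randperm(4) for _ in range(2)], dim=1)  # (tokens, batch)
#   shuffled = take_indexes(tokens, indexes)  # (4, 2, 8), tokens permuted per batch element
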
def random_indexes(size: int, device):
    """Return a random permutation of ``size`` indexes together with its inverse permutation."""
    forward_indexes = torch.randperm(size, device=device, dtype=torch.long)
    backward_indexes = torch.argsort(forward_indexes)
    return forward_indexes, backward_indexes
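
# Illustrative usage (a minimal sketch, not part of the original module): the backward indexes
# undo the forward shuffle, since ``forward_indexes[backward_indexes]`` is sorted. In
# masked-autoencoder style shuffling these 1D permutations are typically stacked per batch
# element before being passed to ``take_indexes`` (an assumption about typical use).
#
#   forward_indexes, backward_indexes = random_indexes(16, device="cpu")
#   assert torch.equal(forward_indexes[backward_indexes], torch.arange(16))
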
def get_positional_embedding(
    num_patches: Sequence[int], emb_dim: int, use_cls_token: bool = True, learnable: bool = True
):
    """Generate a positional embedding (with or without a cls token) for a given number of
    patches and embedding dimension.

    Can be either learnable or fixed.
    """
    if learnable:
        pe = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + int(use_cls_token), 1, emb_dim))
        trunc_normal_(pe, std=0.02)
        return pe
    else:
        test_tensor = torch.ones(1, *num_patches, emb_dim)
        if len(num_patches) not in (2, 3):
            raise ValueError("Only 2d and 3d positional encodings are supported")
        if len(num_patches) == 2:
            pe = PositionalEncoding2D(emb_dim)(test_tensor)
            pe = rearrange(pe, "b y x c -> (y x) b c")
        elif len(num_patches) == 3:
            pe = PositionalEncoding3D(emb_dim)(test_tensor)
            pe = rearrange(pe, "b z y x c -> (z y x) b c")
        if use_cls_token:
            cls_token = torch.zeros(1, 1, emb_dim)
            pe = torch.cat([cls_token, pe], dim=0)
        return torch.nn.Parameter(pe, requires_grad=False)
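
# Illustrative usage (a minimal sketch, not part of the original module): a fixed sine/cosine
# embedding for an 8x8 patch grid with a cls token prepended. The shapes follow the
# (tokens, batch, channels) convention used elsewhere in this file.
#
#   pe = get_positional_embedding((8, 8), emb_dim=192, use_cls_token=True, learnable=False)
#   pe.shape           # torch.Size([65, 1, 192]) -> 8 * 8 patch tokens + 1 cls token
#   pe.requires_grad   # False; pass learnable=True for a trainable, truncated-normal embedding
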
def match_tuple_dimensions(spatial_dims, tuples):
    """Ensure that each element in a list of tuples has the same length as spatial_dims.

    If an element is a single value, it is repeated to match spatial_dims.
    """
    assert spatial_dims in (2, 3), "spatial_dims must be 2 or 3"
    return [ensure_tuple_rep(t, spatial_dims) for t in tuples]
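
# Illustrative usage (a minimal sketch, not part of the original module): scalars are broadcast
# to per-dimension tuples, while tuples of the right length pass through unchanged.
#
#   match_tuple_dimensions(3, [4, (1, 2, 3)])  # [(4, 4, 4), (1, 2, 3)]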