from typing import List, Optional
import numpy as np
from einops.layers.torch import Rearrange
from cyto_dl.nn.vits.blocks.patchify.patchify_base import PatchifyBase
from cyto_dl.nn.vits.utils import take_indexes
class Patchify(PatchifyBase):
    """Class for converting images to a masked sequence of patches with positional embeddings.

    This subclass only supplies the image-to-token rearrangement and the
    patch-count bookkeeping used for masking; every constructor argument is
    forwarded to ``PatchifyBase``.
    """

    def __init__(
        self,
        patch_size: List[int],
        emb_dim: int,
        n_patches: List[int],
        spatial_dims: int = 3,
        context_pixels: List[int] = [0, 0, 0],
        input_channels: int = 1,
        tasks: Optional[List[str]] = [],
        learnable_pos_embedding: bool = True,
    ):
        """Forward all configuration to ``PatchifyBase``.

        The meaning of each argument is defined by ``PatchifyBase`` (not
        visible from this file); consult that class for details.

        Parameters
        ----------
        patch_size: List[int]
            Passed through to the base class.
        emb_dim: int
            Passed through to the base class.
        n_patches: List[int]
            Passed through to the base class; ``get_mask_args`` reads
            ``self.n_patches`` (presumably set from this value by the base
            class -- confirm) as the per-axis patch counts.
        spatial_dims: int
            Passed through to the base class; ``create_img2token`` reads
            ``self.spatial_dims`` (presumably set from this value by the base
            class -- confirm) and supports only 2 or 3.
        context_pixels: List[int]
            Passed through to the base class.
        input_channels: int
            Passed through to the base class.
        tasks: Optional[List[str]]
            Passed through to the base class.
        learnable_pos_embedding: bool
            Passed through to the base class.
        """
        # The signature keeps its list defaults for interface compatibility,
        # but those default objects are shared by every call. Hand the base
        # class a private copy so nothing downstream can mutate them.
        if isinstance(context_pixels, list):
            context_pixels = list(context_pixels)
        if isinstance(tasks, list):
            tasks = list(tasks)

        super().__init__(
            patch_size=patch_size,
            emb_dim=emb_dim,
            n_patches=n_patches,
            spatial_dims=spatial_dims,
            context_pixels=context_pixels,
            input_channels=input_channels,
            tasks=tasks,
            learnable_pos_embedding=learnable_pos_embedding,
        )

    @property
    def img2token(self):
        """Layer returned by ``create_img2token``; rebuilt on every access."""
        return self.create_img2token()

    def get_mask_args(self, mask_ratio):
        """Return the patch counts needed to build a mask.

        Parameters
        ----------
        mask_ratio:
            Fraction of patches to mask. It is used as given; no range check
            is performed here.

        Returns
        -------
        tuple
            ``(n_visible_patches, num_patches)`` where ``num_patches`` is the
            product of ``self.n_patches`` and ``n_visible_patches`` is
            ``int(num_patches * (1 - mask_ratio))`` (truncated toward zero).
        """
        num_patches = np.prod(self.n_patches)
        n_visible_patches = int(num_patches * (1 - mask_ratio))
        return n_visible_patches, num_patches

    def create_img2token(self):
        """Build the layer that rearranges an image tensor to a sequence of patches.

        Returns
        -------
        Rearrange
            For ``spatial_dims == 3``: ``"b c z y x -> (z y x) b c"``.
            For ``spatial_dims == 2``: ``"b c y x -> (y x) b c"``.
            In both cases the spatial axes are flattened into a single
            leading axis, followed by the batch and channel axes.

        Raises
        ------
        ValueError
            If ``self.spatial_dims`` is neither 2 nor 3.
        """
        if self.spatial_dims == 3:
            return Rearrange("b c z y x -> (z y x) b c")
        if self.spatial_dims == 2:
            return Rearrange("b c y x -> (y x) b c")
        # Previously this fell through and returned None, which only failed
        # later when the "layer" was called; fail here with a clear message.
        raise ValueError(f"spatial_dims must be 2 or 3, got {self.spatial_dims!r}")