cyto_dl.nn.vits.blocks.patchify.patchify_hiera module

class cyto_dl.nn.vits.blocks.patchify.patchify_hiera.PatchifyHiera(patch_size: List[int], n_patches: List[int], emb_dim: int = 64, spatial_dims: int = 3, context_pixels: List[int] = [0, 0, 0], input_channels: int = 1, tasks: List[str] | None = [], mask_units_per_dim: List[int] = [8, 8, 8])[source]

Bases: PatchifyBase

Class for converting images to a sequence of patches with positional embeddings, masked at the level of mask units (groups of patches specified by mask_units_per_dim).

patch_size: List[int]

Size of each patch in pixels (ZYX order for 3D, YX order for 2D)

n_patches: List[int]

Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D)

emb_dim: int

Dimension of the encoder embedding

spatial_dims: int

Number of spatial dimensions

context_pixels: List[int]

Number of extra pixels around each patch (per spatial dimension) to include in the convolutional embedding to the encoder dimension.

input_channels: int

Number of input channels

tasks: List[str]

List of tasks to encode

mask_units_per_dim: List[int]

Number of mask units in each spatial dimension (ZYX order for 3D, YX order for 2D)
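
For orientation, a minimal instantiation sketch follows. The argument values are illustrative, not prescribed; with n_patches=[16, 16, 16] and mask_units_per_dim=[8, 8, 8], each mask unit would group roughly 2 x 2 x 2 patches, and the expected input would then span patch_size * n_patches = 64 pixels per spatial dimension.

    from cyto_dl.nn.vits.blocks.patchify.patchify_hiera import PatchifyHiera

    # Illustrative values only
    patchify = PatchifyHiera(
        patch_size=[4, 4, 4],          # ZYX pixels per patch
        n_patches=[16, 16, 16],        # ZYX patches per spatial dimension
        emb_dim=64,                    # encoder embedding dimension
        spatial_dims=3,
        context_pixels=[0, 0, 0],      # no extra context around each patch
        input_channels=1,
        mask_units_per_dim=[8, 8, 8],  # ZYX mask units per spatial dimension
    )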

create_img2token(mask_units_per_dim)[source]

extract_visible_tokens(tokens, forward_indexes, n_visible_patches)[source]

get_mask_args(mask_ratio)[source]

property img2token
cyto_dl.nn.vits.blocks.patchify.patchify_hiera.take_indexes_mask(sequences, indexes)[source]

sequences: batch x mask units x patches x emb_dim
indexes: mask_units x batch
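
The shapes above suggest a per-batch gather over whole mask units. The sketch below is an assumption-laden illustration of that kind of gather (the name take_indexes_mask_sketch and the interpretation of indexes as integer mask-unit indices are hypothetical), not the actual implementation.

    import torch
    from einops import repeat

    def take_indexes_mask_sketch(sequences, indexes):
        # sequences: (batch, mask_units, patches, emb_dim)
        # indexes:   (mask_units, batch), integer (long) mask-unit indices, e.g. from argsort
        p, c = sequences.shape[2], sequences.shape[3]
        # Expand indexes to (batch, mask_units, patches, emb_dim) so gather can
        # select whole mask units along dim 1 for each batch element.
        gather_idx = repeat(indexes, "mu b -> b mu p c", p=p, c=c)
        return torch.gather(sequences, 1, gather_idx)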