cyto_dl.nn.vits.blocks.masked_unit_attention module
- class cyto_dl.nn.vits.blocks.masked_unit_attention.HieraBlock(dim: int, dim_out: int, heads: int, spatial_dims: int = 3, mlp_ratio: float = 4.0, drop_path: float = 0.0, norm_layer: torch.nn.Module = torch.nn.LayerNorm, act_layer: torch.nn.Module = torch.nn.GELU, q_stride: List[int] = [1, 1, 1], patches_per_mask_unit: List[int] = [2, 12, 12])
Bases: torch.nn.Module
- Parameters:
dim (int) – Dimension of the input features.
dim_out (int) – Dimension of the output features.
heads (int) – Number of attention heads.
spatial_dims (int, optional) – Number of spatial dimensions, by default 3.
mlp_ratio (float, optional) – Ratio of MLP hidden dim to embedding dim, by default 4.0.
drop_path (float, optional) – Stochastic depth (drop path) rate for the residual branches, by default 0.0.
norm_layer (nn.Module class, optional) – Normalization layer class, by default nn.LayerNorm.
act_layer (nn.Module class, optional) – Activation layer class for the MLP, by default nn.GELU.
q_stride (List[int], optional) – Per-axis stride for pooling the query tokens, by default [1, 1, 1] (no pooling).
patches_per_mask_unit (List[int], optional) – Number of patches per mask unit along each spatial axis, by default [2, 12, 12].
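As a rough illustration of how patches_per_mask_unit and q_stride interact, the sketch below counts the tokens in one mask unit and the query tokens left after pooling. It assumes, without confirmation from this page, that query pooling divides each spatial extent of a mask unit by the matching q_stride entry. The helper names tokens_per_mask_unit and queries_per_mask_unit are hypothetical and are not part of cyto_dl.

```python
from math import prod
from typing import Sequence


def tokens_per_mask_unit(patches_per_mask_unit: Sequence[int]) -> int:
    """Hypothetical helper: patch tokens in one mask unit (product of per-axis extents)."""
    return prod(patches_per_mask_unit)


def queries_per_mask_unit(patches_per_mask_unit: Sequence[int], q_stride: Sequence[int]) -> int:
    """Hypothetical helper: query tokens left in one mask unit after pooling.

    Assumption: pooling shrinks each spatial axis of the mask unit by the matching
    q_stride entry, so every extent must be divisible by its stride.
    """
    if len(patches_per_mask_unit) != len(q_stride):
        raise ValueError("patches_per_mask_unit and q_stride need one entry per spatial dim")
    for p, s in zip(patches_per_mask_unit, q_stride):
        if p % s != 0:
            raise ValueError(f"extent {p} is not divisible by stride {s}")
    return prod(p // s for p, s in zip(patches_per_mask_unit, q_stride))


# Defaults from the signature: 2 x 12 x 12 patches per mask unit, no query pooling.
print(tokens_per_mask_unit([2, 12, 12]))              # 288
print(queries_per_mask_unit([2, 12, 12], [1, 1, 1]))  # 288
# A hypothetical q_stride that halves the two in-plane axes.
print(queries_per_mask_unit([2, 12, 12], [1, 2, 2]))  # 72
```

Under this assumption, the default q_stride of [1, 1, 1] leaves the token count unchanged, and any stride greater than 1 must divide the corresponding mask-unit extent.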
- class cyto_dl.nn.vits.blocks.masked_unit_attention.MaskUnitAttention(dim, dim_out, spatial_dims: int = 3, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, q_stride=[1, 1, 1], patches_per_mask_unit=[2, 12, 12])
Bases: torch.nn.Module
- Parameters:
dim (int) – Dimension of the input features.
dim_out (int) – Dimension of the output features.
spatial_dims (int, optional) – Number of spatial dimensions, by default 3.
num_heads (int, optional) – Number of attention heads, by default 8.
qkv_bias (bool, optional) – If True, add a learnable bias to the query, key, and value projections, by default False.
attn_drop (float, optional) – Dropout rate for attention, by default 0.0.
proj_drop (float, optional) – Dropout rate for projection, by default 0.0.
q_stride (List[int], optional) – Per-axis stride for pooling the query tokens, by default [1, 1, 1] (no pooling).
patches_per_mask_unit (List[int], optional) – Number of patches per mask unit along each spatial axis, by default [2, 12, 12].
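Assuming MaskUnitAttention follows the Hiera design, in which patch tokens are grouped into mask units and each token attends only to tokens in its own unit, that restriction can be sketched in plain Python. This is a simplified illustration, not the cyto_dl implementation: the helper names mask_unit_ids and mask_unit_attention are hypothetical, the sketch uses a single head with identity projections (q = k = v = x) and no query pooling, and tokens are ordered row-major. The real module uses learned qkv projections and num_heads heads, and may order tokens differently.

```python
import math
from itertools import product
from typing import Dict, List, Sequence


def mask_unit_ids(grid: Sequence[int], unit: Sequence[int]) -> List[int]:
    """Hypothetical helper: mask-unit id for each token of a row-major flattened patch grid."""
    if len(grid) != len(unit) or any(g % u for g, u in zip(grid, unit)):
        raise ValueError("every grid extent must be divisible by the mask-unit extent")
    units_per_axis = [g // u for g, u in zip(grid, unit)]
    ids = []
    for coord in product(*(range(g) for g in grid)):
        # Mask-unit coordinate along each axis, flattened row-major into one id.
        uid = 0
        for c, u, n in zip(coord, unit, units_per_axis):
            uid = uid * n + c // u
        ids.append(uid)
    return ids


def mask_unit_attention(x: List[List[float]], ids: List[int]) -> List[List[float]]:
    """Hypothetical helper: single-head scaled dot-product attention within each mask unit.

    Simplification: q = k = v = x (no learned projections, no query pooling).
    """
    dim = len(x[0])
    scale = dim ** -0.5
    members: Dict[int, List[int]] = {}
    for i, uid in enumerate(ids):
        members.setdefault(uid, []).append(i)
    out = [[0.0] * dim for _ in x]
    for idx in members.values():
        for i in idx:
            # Scores only against tokens that share token i's mask unit.
            scores = [scale * sum(a * b for a, b in zip(x[i], x[j])) for j in idx]
            peak = max(scores)
            weights = [math.exp(s - peak) for s in scores]  # numerically stable softmax
            total = sum(weights)
            out[i] = [sum(w * x[j][k] for w, j in zip(weights, idx)) / total for k in range(dim)]
    return out


# A 2 x 4 x 4 patch grid split into 2 x 2 x 2 mask units gives 4 units of 8 tokens each.
ids = mask_unit_ids([2, 4, 4], [2, 2, 2])
tokens = [[float(i % 3), float(i % 5)] for i in range(len(ids))]
out = mask_unit_attention(tokens, ids)
print(len(set(ids)), len(out))  # 4 32
```

Because the softmax runs over one mask unit at a time, editing a token changes only the outputs of tokens in the same unit. This locality is what keeps the cost of attention bounded by the mask-unit size rather than by the full grid.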