cyto_dl.nn.vits.blocks.masked_unit_attention module#

class cyto_dl.nn.vits.blocks.masked_unit_attention.HieraBlock(dim: int, dim_out: int, heads: int, spatial_dims: int = 3, mlp_ratio: float = 4.0, drop_path: float = 0.0, norm_layer: ~torch.nn.modules.module.Module = <class 'torch.nn.modules.normalization.LayerNorm'>, act_layer: ~torch.nn.modules.module.Module = <class 'torch.nn.modules.activation.GELU'>, q_stride: ~typing.List[int] = [1, 1, 1], patches_per_mask_unit: ~typing.List[int] = [2, 12, 12])[source]#

Bases: Module

Parameters:
  • dim (int) – Dimension of the input features.

  • dim_out (int) – Dimension of the output features.

  • heads (int) – Number of attention heads.

  • spatial_dims (int, optional) – Number of spatial dimensions, by default 3.

  • mlp_ratio (float, optional) – Ratio of MLP hidden dim to embedding dim, by default 4.0.

  • drop_path (float, optional) – Stochastic depth (drop path) rate, by default 0.0.

  • norm_layer (nn.Module, optional) – Normalization layer, by default nn.LayerNorm.

  • act_layer (nn.Module, optional) – Activation layer for the MLP, by default nn.GELU.

  • q_stride (List[int], optional) – Stride applied to the query along each spatial dimension, by default [1, 1, 1].

  • patches_per_mask_unit (List[int], optional) – Number of patches per mask unit, by default [2, 12, 12].

forward(x: Tensor) → Tensor[source]#

x: input tensor of shape batch x mask units x tokens x emb_dim.
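The interaction between `patches_per_mask_unit` and `q_stride` determines how many tokens each mask unit carries after the block. A minimal sketch of that arithmetic (the helper name is hypothetical, not part of the library; it assumes q_stride pools the query over each spatial axis, as in Hiera-style blocks):

```python
import math

def tokens_after_q_stride(patches_per_mask_unit, q_stride):
    """Hypothetical helper: tokens per mask unit before and after
    query pooling, assuming each spatial axis is downsampled by its
    q_stride entry."""
    before = math.prod(patches_per_mask_unit)
    pooled = [p // s for p, s in zip(patches_per_mask_unit, q_stride)]
    after = math.prod(pooled)
    return before, after

# Default 3D settings: 2 x 12 x 12 patches per mask unit.
print(tokens_after_q_stride([2, 12, 12], [1, 1, 1]))  # (288, 288)
# A stride of 2 along the two in-plane axes quarters the token count.
print(tokens_after_q_stride([2, 12, 12], [1, 2, 2]))  # (288, 72)
```

With the default `q_stride=[1, 1, 1]` the token count per mask unit is unchanged; strided settings let a block reduce resolution between stages.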

class cyto_dl.nn.vits.blocks.masked_unit_attention.MaskUnitAttention(dim, dim_out, spatial_dims: int = 3, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, q_stride=[1, 1, 1], patches_per_mask_unit=[2, 12, 12])[source]#

Bases: Module

Parameters:
  • dim (int) – Dimension of the input features.

  • dim_out (int) – Dimension of the output features.

  • spatial_dims (int, optional) – Number of spatial dimensions, by default 3.

  • num_heads (int, optional) – Number of attention heads, by default 8.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value, by default False.

  • attn_drop (float, optional) – Dropout rate for attention, by default 0.0.

  • proj_drop (float, optional) – Dropout rate for projection, by default 0.0.

  • q_stride (List[int], optional) – Stride applied to the query along each spatial dimension, by default [1, 1, 1].

  • patches_per_mask_unit (List[int], optional) – Number of patches per mask unit, by default [2, 12, 12].

forward(x)[source]#
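Conceptually, mask-unit attention restricts self-attention to the tokens within a single mask unit. A minimal sketch of that idea using plain `torch.nn.MultiheadAttention` (this is an illustration of the restriction, not the library's implementation; the shapes follow the batch x mask units x tokens x emb_dim convention documented above):

```python
import torch
from torch import nn

# 1 batch, 2 mask units of 2 x 12 x 12 patches = 288 tokens each, dim 96.
batch, units, tokens, dim = 1, 2, 288, 96
x = torch.rand(batch, units, tokens, dim)

# Folding the mask-unit axis into the batch axis means each unit's
# tokens attend only to tokens in the same unit.
attn = nn.MultiheadAttention(embed_dim=dim, num_heads=8, batch_first=True)
flat = x.reshape(batch * units, tokens, dim)
out, _ = attn(flat, flat, flat)
out = out.reshape(batch, units, tokens, dim)
print(out.shape)  # torch.Size([1, 2, 288, 96])
```

The fold-into-batch trick is what makes the attention cost scale with tokens per mask unit rather than with the full sequence length.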