cyto_dl.nn.vits.blocks.cross_attention module
- class cyto_dl.nn.vits.blocks.cross_attention.CrossAttention(encoder_dim, decoder_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]

  Bases: Module
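The constructor takes separate widths for the encoder (key/value) and decoder (query) streams, in the style of timm attention layers. A minimal usage sketch follows; the forward argument order (decoder queries first, encoder context second) and the decoder_dim-sized output are assumptions not confirmed by this page, so check the linked source.

```python
import torch
from cyto_dl.nn.vits.blocks.cross_attention import CrossAttention

# Decoder tokens attend to encoder tokens. The (query, context) argument
# order and the output width are assumptions; verify against the source.
attn = CrossAttention(encoder_dim=768, decoder_dim=512, num_heads=8)
queries = torch.randn(2, 16, 512)    # (batch, decoder tokens, decoder_dim)
context = torch.randn(2, 196, 768)   # (batch, encoder tokens, encoder_dim)
out = attn(queries, context)         # expected shape: (2, 16, 512)
```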
- class cyto_dl.nn.vits.blocks.cross_attention.CrossAttentionBlock(encoder_dim, decoder_dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]

  Bases: Module
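Judging by its signature, CrossAttentionBlock wraps cross-attention in a standard transformer block: normalization via norm_layer, an MLP whose hidden width is mlp_ratio times the embedding width, and stochastic depth via drop_path. A construction sketch, with the same forward-signature caveat as above; all keyword values are illustrative, not recommended settings.

```python
import torch.nn as nn
from cyto_dl.nn.vits.blocks.cross_attention import CrossAttentionBlock

block = CrossAttentionBlock(
    encoder_dim=768,
    decoder_dim=512,
    num_heads=8,              # required: no default in the signature
    mlp_ratio=4.0,            # MLP hidden width = 4.0 * embedding width
    qkv_bias=True,
    drop=0.1,                 # dropout after projections / MLP
    attn_drop=0.1,            # dropout on attention weights
    drop_path=0.1,            # stochastic depth rate
    act_layer=nn.GELU,        # the documented default activation
    norm_layer=nn.LayerNorm,  # the documented default normalization
)
```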
- class cyto_dl.nn.vits.blocks.cross_attention.CrossSelfBlock(emb_dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]

  Bases: Module
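Unlike the two classes above, CrossSelfBlock takes a single emb_dim rather than separate encoder and decoder widths; the name suggests it chains cross-attention and self-attention over tokens of one shared width, though this page does not document the forward pass. A hedged construction sketch:

```python
from cyto_dl.nn.vits.blocks.cross_attention import CrossSelfBlock

# emb_dim and num_heads have no defaults, so both must be passed explicitly.
block = CrossSelfBlock(emb_dim=512, num_heads=8, mlp_ratio=4.0)
```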