cyto_dl.nn.vits.blocks.cross_attention module

class cyto_dl.nn.vits.blocks.cross_attention.CrossAttention(encoder_dim, decoder_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)

Bases: Module

forward(x, y)

Query is computed from the decoder features (x); key and value are computed from the encoder features (y).
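
Example (a minimal usage sketch; the feature dimensions and the expected output shape are illustrative assumptions, not taken from this page):

>>> import torch
>>> from cyto_dl.nn.vits.blocks.cross_attention import CrossAttention
>>> attn = CrossAttention(encoder_dim=768, decoder_dim=512, num_heads=8)
>>> x = torch.randn(2, 16, 512)   # decoder tokens: (batch, num_decoder_tokens, decoder_dim)
>>> y = torch.randn(2, 196, 768)  # encoder tokens: (batch, num_encoder_tokens, encoder_dim)
>>> out = attn(x, y)              # assumed: one attended output token per decoder token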

class cyto_dl.nn.vits.blocks.cross_attention.CrossAttentionBlock(encoder_dim, decoder_dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=torch.nn.GELU, norm_layer=torch.nn.LayerNorm)

Bases: Module

forward(x, y)

x is the decoder feature; y is the encoder feature (after layer normalization).
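
Example (a minimal sketch; the block is presumably cross-attention followed by an MLP with residual connections, as in a standard transformer block, and the dimensions below are illustrative assumptions):

>>> import torch
>>> from cyto_dl.nn.vits.blocks.cross_attention import CrossAttentionBlock
>>> block = CrossAttentionBlock(encoder_dim=768, decoder_dim=512, num_heads=8, mlp_ratio=4.0)
>>> x = torch.randn(2, 16, 512)   # decoder features
>>> y = torch.randn(2, 196, 768)  # encoder features (already layer-normalized, per the docstring)
>>> x = block(x, y)               # assumed: decoder features updated with encoder context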

class cyto_dl.nn.vits.blocks.cross_attention.CrossSelfBlock(emb_dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=torch.nn.GELU, norm_layer=torch.nn.LayerNorm)

Bases: Module

forward(x, y)

x is the decoder feature; y is the encoder feature.
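
Example (a minimal sketch; unlike the blocks above, both inputs share a single embedding dimension, emb_dim; the dimensions below are illustrative assumptions):

>>> import torch
>>> from cyto_dl.nn.vits.blocks.cross_attention import CrossSelfBlock
>>> block = CrossSelfBlock(emb_dim=512, num_heads=8)
>>> x = torch.randn(2, 16, 512)   # decoder features
>>> y = torch.randn(2, 196, 512)  # encoder features, same embedding dimension
>>> x = block(x, y)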

class cyto_dl.nn.vits.blocks.cross_attention.Mlp(in_features, hidden_features=None, out_features=None, act_layer=torch.nn.GELU, drop=0.0)

Bases: Module

forward(x)
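
The signature matches the common two-layer transformer feed-forward block (linear, activation, dropout, linear, dropout). Example (a minimal sketch; the fallback of hidden_features and out_features to in_features when left as None is an assumption based on the usual timm-style Mlp, not confirmed by this page):

>>> import torch
>>> from cyto_dl.nn.vits.blocks.cross_attention import Mlp
>>> mlp = Mlp(in_features=512, hidden_features=2048, drop=0.1)  # out_features=None assumed to fall back to in_features
>>> tokens = torch.randn(2, 196, 512)
>>> out = mlp(tokens)  # same leading dimensions; last dim expected to equal out_features (here 512)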