cyto_dl.nn.vits.blocks.intermediate_weigher module

class cyto_dl.nn.vits.blocks.intermediate_weigher.IntermediateWeigher(num_layers, embed_dim, n_outputs, norm_layer=torch.nn.LayerNorm)

Bases: Module
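The following is a minimal sketch of how a module with this constructor signature might be written in plain PyTorch. It is not the cyto_dl source. The class name `IntermediateWeigherSketch`, the assumed input layout `(num_layers, batch, tokens, embed_dim)`, and the learnable `(num_layers, n_outputs)` weight matrix initialized to `1 / num_layers` are all illustrative assumptions.

```python
import torch
from torch import nn


class IntermediateWeigherSketch(nn.Module):
    """Hypothetical re-implementation for illustration only.

    Assumes intermediate features arrive stacked as a tensor of shape
    (num_layers, batch, tokens, embed_dim) and that the mixing weights are a
    learnable (num_layers, n_outputs) matrix. Neither detail is confirmed by
    the cyto_dl documentation.
    """

    def __init__(self, num_layers, embed_dim, n_outputs, norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_layers = num_layers
        # One normalization layer per intermediate feature.
        self.norms = nn.ModuleList([norm_layer(embed_dim) for _ in range(num_layers)])
        # Assumed: one learnable weight per (layer, output) pair, starting as a uniform average.
        self.weights = nn.Parameter(torch.full((num_layers, n_outputs), 1.0 / num_layers))

    def forward(self, x):
        if x.shape[0] != self.num_layers:
            raise ValueError(
                f"expected {self.num_layers} intermediate features, got {x.shape[0]}"
            )
        # Normalize each intermediate feature with its own norm layer, then restack.
        normed = torch.stack([norm(feat) for norm, feat in zip(self.norms, x)])
        # Weighted sum over the layer axis, once per output:
        # (L, B, N, D) x (L, O) -> (B, N, D, O), so the last dimension is n_outputs.
        return torch.einsum("lbnd,lo->bndo", normed, self.weights)


if __name__ == "__main__":
    weigher = IntermediateWeigherSketch(num_layers=4, embed_dim=8, n_outputs=3)
    feats = torch.randn(4, 2, 5, 8)  # 4 intermediate features, batch 2, 5 tokens, dim 8
    print(weigher(feats).shape)  # torch.Size([2, 5, 8, 3])
```

Using a single `torch.einsum` call keeps all `n_outputs` weighted sums in one operation instead of looping over outputs.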

forward(x)

Apply layer normalization to each intermediate feature, then return n_outputs weighted sums of the normalized features, stacked so that the last dimension of the output has size n_outputs.
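Written as a formula, and assuming one weight $w_{l,k}$ per (layer, output) pair (whether these weights are learnable, constrained, or normalized, e.g. by a softmax, is not stated here), the description above reads as follows, with $L$ = num_layers, $K$ = n_outputs, $x_l$ the $l$-th intermediate feature, and $\mathrm{LN}_l$ its norm layer:

```latex
y_k = \sum_{l=1}^{L} w_{l,k}\,\mathrm{LN}_l(x_l), \qquad k = 1, \dots, K,
\qquad \text{output} = \operatorname{stack}(y_1, \dots, y_K;\ \text{dim}=-1)
```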