cyto_dl.nn.vits.blocks.intermediate_weigher module

class cyto_dl.nn.vits.blocks.intermediate_weigher.IntermediateWeigher(num_layers, embed_dim, n_outputs, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)

    Bases: Module

    forward(x)

        Apply layer norm to each intermediate feature and return n_outputs weighted sums. The last dimension of the output is n_outputs.
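The description of forward(x) can be illustrated with a minimal sketch. This is not cyto_dl's implementation: the class name `WeigherSketch`, the `weights` parameter, the softmax over the layer axis, and the assumption that `x` is a sequence of `num_layers` tensors whose last dimension is `embed_dim` are all hypothetical choices made for illustration. Only the constructor arguments and the "layer norm, then n_outputs weighted sums" behavior come from the documentation above.

```python
import torch
from torch import nn


class WeigherSketch(nn.Module):
    """Hypothetical stand-in for illustration, not the cyto_dl class."""

    def __init__(self, num_layers, embed_dim, n_outputs, norm_layer=nn.LayerNorm):
        super().__init__()
        # One normalization layer per intermediate feature.
        self.norms = nn.ModuleList([norm_layer(embed_dim) for _ in range(num_layers)])
        # Assumed parameterization: one learnable weight per (layer, output) pair.
        self.weights = nn.Parameter(torch.ones(num_layers, n_outputs))

    def forward(self, x):
        # x: sequence of num_layers tensors, each of shape (..., embed_dim).
        normed = torch.stack([norm(f) for norm, f in zip(self.norms, x)], dim=-1)
        # Assumption: weights are normalized over layers so each output is a
        # convex combination of the normalized features.
        # (..., embed_dim, num_layers) @ (num_layers, n_outputs)
        #   -> (..., embed_dim, n_outputs)
        return normed @ torch.softmax(self.weights, dim=0)


torch.manual_seed(0)
# Three intermediate features: batch of 2, 5 tokens, embedding size 8.
features = [torch.randn(2, 5, 8) for _ in range(3)]
weigher = WeigherSketch(num_layers=3, embed_dim=8, n_outputs=4)
out = weigher(features)
print(out.shape)  # torch.Size([2, 5, 8, 4])
```

Under these assumptions the output gains a trailing axis of size n_outputs, which matches the documented "last dimension is n_outputs". With the weights initialized to ones, every output starts as the plain average of the normalized features.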