Source code for cyto_dl.nn.vits.blocks.intermediate_weigher

import torch
from einops import rearrange


class IntermediateWeigher(torch.nn.Module):
    """Combine per-layer intermediate features into learned weighted sums.

    Each layer's feature is passed through its own normalization module, the
    normalized features are stacked along a new trailing axis, and a linear
    map over that axis produces ``n_outputs`` weighted combinations of the
    layers.
    """

    def __init__(
        self,
        num_layers: int,
        embed_dim: int,
        n_outputs: int,
        norm_layer=torch.nn.LayerNorm,
    ):
        """Build the per-layer norms and the layer-mixing linear map.

        Parameters
        ----------
        num_layers
            Number of intermediate features that will be combined.
        embed_dim
            Size passed to ``norm_layer`` when constructing each norm.
        n_outputs
            Number of weighted sums to produce.
        norm_layer
            Callable invoked as ``norm_layer(embed_dim)`` once per layer to
            build that layer's normalization module.
        """
        super().__init__()
        self.weights = torch.nn.Linear(num_layers, n_outputs)
        # initialize with equal weighting of all layers, so every output
        # starts out as the plain mean of the normalized layer features
        torch.nn.init.constant_(self.weights.weight, 1.0 / num_layers)
        torch.nn.init.zeros_(self.weights.bias)
        self.norms = torch.nn.ModuleList(
            [norm_layer(embed_dim) for _ in range(num_layers)]
        )

    def forward(self, x) -> torch.Tensor:
        """Normalize each intermediate feature and return weighted sums.

        Parameters
        ----------
        x
            Indexable collection of per-layer features; ``x[i]`` is fed to the
            ``i``-th norm. Only the first ``num_layers`` entries are read. All
            entries must share one shape so they can be stacked.
            NOTE(review): the previous implementation named the feature axes
            ``b t c`` (presumably batch, tokens, channels) -- confirm with
            callers.

        Returns
        -------
        torch.Tensor
            Tensor whose leading axis has size ``n_outputs``, followed by the
            axes of a single input feature (``n b t c`` for ``b t c`` input).
        """
        normalized = [norm(x[i]) for i, norm in enumerate(self.norms)]
        # New trailing axis indexes the layers so the Linear mixes across them.
        stacked = torch.stack(normalized, dim=-1)
        mixed = self.weights(stacked)
        # Bring the n_outputs axis to the front regardless of feature rank.
        return torch.movedim(mixed, -1, 0)