Source code for cyto_dl.nn.head.vic_reg

from typing import List, Optional

from cyto_dl.nn import MLP
from cyto_dl.nn.head import BaseHead


class VICRegHead(BaseHead):
    """Head that projects two sets of backbone features through a shared MLP
    and computes a loss between the two projections."""

    def __init__(
        self,
        loss,
        dims: Optional[List[int]] = None,
        hidden_layers: Optional[List[int]] = None,
    ):
        """
        Parameters
        ----------
        loss
            Loss function for task
        dims
            Input and output dimensions for the projector network.
            Defaults to ``[2048, 8192]`` when ``None``.
        hidden_layers
            Hidden layers for the MLP. Defaults to ``[8192, 8192]`` when
            ``None``.
        """
        super().__init__(loss)
        # Build the defaults per call rather than in the signature: a list
        # literal used as a default argument is created once and shared by
        # every instance, so an in-place change would leak between heads.
        if dims is None:
            dims = [2048, 8192]
        if hidden_layers is None:
            hidden_layers = [8192, 8192]
        self.model = MLP(*dims, hidden_layers=hidden_layers)

    def run_head(
        self,
        backbone_features,
        batch,
        stage,
        n_postprocess=1,
        run_forward=True,
        y_hat=None,
    ):
        """Project both backbone feature sets and compute the loss between them.

        Parameters
        ----------
        backbone_features
            Iterable that unpacks into exactly two items; each item is passed
            through ``self.forward`` separately.
        batch, stage, n_postprocess, run_forward, y_hat
            Not used by this method. NOTE(review): presumably kept so the
            signature matches other heads - confirm against ``BaseHead``.

        Returns
        -------
        dict
            ``"loss"``: result of ``self._calculate_loss`` on the two
            projections; ``"y_hat_out"``: projection of the first item;
            ``"y_out"``: projection of the second item.
        """
        x1, x2 = backbone_features
        x1, x2 = self.forward(x1), self.forward(x2)
        loss = self._calculate_loss(x1, x2)
        return {
            "loss": loss,
            "y_hat_out": x1,
            "y_out": x2,
        }