# Source code for cyto_dl.nn.losses.vic_reg

# modified from https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py

import torch
import torch.nn.functional as F


def off_diagonal(x):
    """Flatten every element of a square matrix except its main diagonal.

    Args:
        x: A 2-D tensor of shape ``(n, n)``.

    Returns:
        A 1-D tensor of length ``n * (n - 1)`` holding the off-diagonal
        entries in row-major order.
    """
    rows, cols = x.shape
    assert rows == cols
    # In the flattened matrix the diagonal entries sit every (n + 1) positions,
    # starting at index 0 and ending at the final element. Dropping that final
    # element and folding the rest into rows of width n + 1 lines every
    # remaining diagonal entry up in column 0.
    folded = x.flatten()[:-1].view(rows - 1, rows + 1)
    # Discard column 0 (the diagonal) and flatten what is left.
    without_diagonal = folded[:, 1:]
    return without_diagonal.flatten()
class VICRegLoss(torch.nn.Module):
    """VICReg (Variance-Invariance-Covariance Regularization) loss.

    Combines three terms computed between two batches of embeddings ``x`` and
    ``y``:

    * invariance: mean squared error between ``x`` and ``y``;
    * variance: hinge penalty ``relu(1 - std)`` pushing the per-feature
      standard deviation of each batch up to at least 1;
    * covariance: sum of squared off-diagonal entries of each batch's feature
      covariance matrix, divided by the feature dimension.

    Args:
        num_features: Embedding dimensionality. Stored on the module but not
            used by :meth:`forward`, which divides by the feature dimension of
            the actual input instead.
        sim_coeff: Weight of the invariance term.
        std_coeff: Weight of the variance term.
        cov_coeff: Weight of the covariance term.
    """

    def __init__(self, num_features, sim_coeff=5, std_coeff=5, cov_coeff=1):
        super().__init__()
        self.num_features = num_features
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff

    def forward(self, x, y):
        """Compute the weighted VICReg loss between two views.

        Args:
            x: Embeddings of the first view, shape ``(batch_size, num_features)``.
            y: Embeddings of the second view, same shape as ``x``.

        Returns:
            Scalar tensor ``sim_coeff * invariance + std_coeff * variance
            + cov_coeff * covariance``.

        Raises:
            ValueError: If ``x`` and ``y`` have different shapes, or if the
                batch holds fewer than two samples. With a single sample the
                unbiased variance and the ``batch_size - 1`` covariance
                normalization are undefined, and the loss would silently be NaN.
        """
        if x.shape != y.shape:
            raise ValueError(
                f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
            )
        batch_size, num_features = x.shape
        if batch_size < 2:
            raise ValueError(f"VICRegLoss needs a batch size of at least 2, got {batch_size}")

        # view invariance loss: pull the two views' embeddings together
        repr_loss = F.mse_loss(x, y)

        # center every feature over the batch before computing statistics
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # variance loss: the small epsilon keeps sqrt away from 0, where its
        # gradient is unbounded; the two views' hinge terms are averaged
        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        # covariance loss: penalize off-diagonal covariances to decorrelate
        # features; normalized by the input's actual feature dimension
        cov_x = (x.T @ x) / (batch_size - 1)
        cov_y = (y.T @ y) / (batch_size - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(num_features) + off_diagonal(
            cov_y
        ).pow_(2).sum().div(num_features)

        loss = self.sim_coeff * repr_loss + self.std_coeff * std_loss + self.cov_coeff * cov_loss
        return loss