# Source code for cyto_dl.nn.losses.vic_reg
# modified from https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
import torch
import torch.nn.functional as F
def off_diagonal(x):
    """Return the off-diagonal elements of a square matrix as a 1-D tensor.

    Elements come out in row-major order with the diagonal skipped; for a
    3x3 input that is ``x[0,1], x[0,2], x[1,0], x[1,2], x[2,0], x[2,1]``.
    The result may share storage with ``x`` when PyTorch's ``flatten`` can
    avoid a copy, so in-place operations on it can write through to ``x``.

    Raises:
        ValueError: if ``x`` is not 2-D (from unpacking ``x.shape``).
        AssertionError: if ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    # In the flattened matrix the diagonal sits at every (rows + 1)-th index.
    # Dropping the leading element x[0, 0] and folding the rest into rows of
    # length rows + 1 puts each remaining diagonal entry x[i+1, i+1] in the
    # last column; slicing that column off leaves exactly the off-diagonal
    # entries, still in row-major order.
    folded = x.flatten()[1:].view(rows - 1, rows + 1)
    return folded[:, :-1].flatten()
class VICRegLoss(torch.nn.Module):
    """Variance-Invariance-Covariance Regularization (VICReg) loss.

    Combines three terms computed on two batches of embeddings ``x`` and ``y``:

    * invariance: mean squared error between ``x`` and ``y``;
    * variance: hinge loss that is zero once a feature's standard deviation
      reaches 1, averaged over features and over the two batches;
    * covariance: sum of squared off-diagonal entries of each batch's feature
      covariance matrix, divided by the number of features.

    Args:
        num_features: Expected embedding dimensionality. Stored as
            ``self.num_features`` but not used by :meth:`forward`, which
            reads the dimensionality from its input instead.
        sim_coeff: Weight of the invariance (MSE) term.
        std_coeff: Weight of the variance term.
        cov_coeff: Weight of the covariance term.
        eps: Constant added to the per-feature variance before the square
            root; keeps the gradient of ``sqrt`` finite when a variance is 0.
    """

    def __init__(self, num_features, sim_coeff=5, std_coeff=5, cov_coeff=1, eps=1e-4):
        super().__init__()
        # Kept for configuration/introspection; forward() uses x.shape[1].
        self.num_features = num_features
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff
        self.eps = eps

    def forward(self, x, y):
        """Compute the weighted VICReg loss for two batches of embeddings.

        Args:
            x: Embeddings of the first view, shape ``(batch_size, num_features)``.
            y: Embeddings of the second view, same shape as ``x``.

        Returns:
            Scalar tensor ``sim_coeff * invariance + std_coeff * variance
            + cov_coeff * covariance``.

        Raises:
            ValueError: if ``x`` and ``y`` differ in shape, are not 2-D, or hold
                fewer than two samples.
        """
        if x.shape != y.shape:
            raise ValueError(
                f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
            )
        if x.ndim != 2:
            raise ValueError(f"expected 2-D inputs (batch_size, num_features), got {x.ndim}-D")
        batch_size, num_features = x.shape
        if batch_size < 2:
            # The unbiased variance and the covariance normalisation both divide
            # by batch_size - 1; a single sample would silently yield NaN/inf.
            raise ValueError(f"VICRegLoss needs at least 2 samples per batch, got {batch_size}")

        # Invariance: the two views' embeddings should match (computed on the
        # raw, un-centred inputs).
        repr_loss = F.mse_loss(x, y)

        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # Variance: hinge on the per-feature standard deviation.
        std_x = torch.sqrt(x.var(dim=0) + self.eps)
        std_y = torch.sqrt(y.var(dim=0) + self.eps)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        # Covariance: penalise correlation between different features, i.e. the
        # off-diagonal entries of each (num_features, num_features) covariance.
        cov_x = (x.T @ x) / (batch_size - 1)
        cov_y = (y.T @ y) / (batch_size - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(num_features) + off_diagonal(
            cov_y
        ).pow_(2).sum().div(num_features)

        loss = self.sim_coeff * repr_loss + self.std_coeff * std_loss + self.cov_coeff * cov_loss
        return loss