Source code for cyto_dl.nn.losses.cosine_loss

import torch.nn.functional as F
from torch.nn.modules.loss import _Loss as Loss


[docs]class CosineLoss(Loss): def __init__(self, reduction="mean"): super().__init__(None, None, reduction)
[docs] def forward(self, input, target): # sum per input-element log loss loss = 1 - F.cosine_similarity(input, target) # reduce across batch dimension if self.reduction == "none": return loss elif self.reduction == "sum": return loss.sum() elif self.reduction == "mean": return loss.mean() else: raise NotImplementedError(f"Unavailable reduction type: {self.reduction}")