Source code for cyto_dl.nn.losses.gaussian_nll_loss
import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss as Loss
class GaussianNLLLoss(Loss):
    """Gaussian negative log-likelihood with a scale estimated from the data.

    The standard deviation is not predicted by the model: it is computed
    as the root-mean-square error between ``target`` and ``input`` over
    ``mean_dims`` and then detached from the autograd graph, so gradients
    flow only through the squared-error term.

    Args:
        mean_dims: Dimensions over which the RMS error (sigma) is averaged.
            ``None`` (the default) averages over every dimension of the
            input, which yields a single shared sigma.
        eps: Small constant added to sigma before taking the logarithm to
            avoid ``log(0)`` when the prediction matches the target exactly.
    """

    def __init__(self, mean_dims=None, eps=1e-10):
        super().__init__(None, None, "none")
        # Keep ``None`` as-is: ``tuple(None)`` raises TypeError, which made
        # the default constructor unusable.
        self.mean_dims = None if mean_dims is None else tuple(mean_dims)
        # Honour the caller-supplied epsilon instead of a hard-coded value.
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the per-sample Gaussian NLL.

        Args:
            input: Predicted values; the first dimension is treated as the
                batch dimension by the final reshape.
            target: Reference values, broadcast-compatible with ``input``.

        Returns:
            Tensor of shape ``(input.shape[0], 1)`` holding the NLL summed
            over all non-batch elements of each sample.
        """
        if self.mean_dims is None:
            mean_dims = tuple(range(input.dim()))
        else:
            mean_dims = self.mean_dims

        residual = target - input

        # RMS error over ``mean_dims``; keepdim so it broadcasts back
        # against the element-wise residual.
        sigma = (residual**2).mean(mean_dims, keepdim=True).sqrt()

        # Detached: sigma acts as a constant scale during backprop.
        log_sigma = (sigma + self.eps).log().detach()

        # Plain Python float so no NumPy scalar enters the tensor arithmetic.
        half_log_two_pi = 0.5 * float(np.log(2 * np.pi))

        elementwise = (
            0.5 * torch.pow(residual / log_sigma.exp(), 2)
            + log_sigma
            + half_log_two_pi
        )

        # Collapse every non-batch dimension and sum per sample.
        return elementwise.reshape(input.shape[0], -1).sum(dim=1, keepdim=True)