Source code for cyto_dl.nn.losses.weibull
import torch
from torch.distributions import Weibull
from torch.nn.modules.loss import _Loss as Loss
def weibull_log_probs(beta, alpha, target, eps):
    """Return the element-wise Weibull log-density of ``target``.

    Evaluates the closed-form log-pdf of a Weibull distribution with shape
    ``beta`` and scale ``alpha``::

        log(beta) - log(alpha)
            + (beta - 1) * (log(target) - log(alpha))
            - (target / alpha) ** beta

    Parameters
    ----------
    beta
        Shape (concentration) parameter tensor.
    alpha
        Scale parameter tensor; ``eps`` is added to it before use.
    target
        Observed values at which the density is evaluated.
    eps
        Small constant added to ``alpha`` so that the logarithm and the
        division never see an exact zero scale.

    Returns
    -------
    Tensor of log-probabilities produced by broadcasting the inputs
    together; no reduction is applied.
    """
    # Only the scale is offset; beta and target are used as given.
    alpha = alpha + eps
    log_ratio = torch.log(target) - torch.log(alpha)
    return (
        torch.log(beta)
        - torch.log(alpha)
        + (beta - 1) * log_ratio
        # The exponent term is (x / alpha) ** beta, not beta * (x / alpha);
        # the latter only coincides with the true density in special cases
        # (e.g. beta == 1).
        - torch.pow(target / alpha, beta)
    )
class WeibullLogLoss(Loss):
    """Negative log-likelihood loss for Weibull-distributed targets.

    Parameters
    ----------
    reduction
        ``"none"`` returns the per-element negative log-probabilities,
        ``"sum"`` returns their sum and ``"mean"`` their mean. Any other
        value makes ``forward`` raise ``NotImplementedError``.
    mode
        ``"explicit"`` evaluates the log-probabilities with
        ``weibull_log_probs``; any other value uses
        ``torch.distributions.Weibull``.
    eps
        Constant forwarded to ``weibull_log_probs`` in ``"explicit"`` mode;
        not used otherwise.
    """

    def __init__(self, reduction="sum", mode="explicit", eps=1e-10):
        super().__init__(None, None, reduction)
        self.eps = eps
        self.mode = mode

    def forward(self, a, b, target):
        """Return the negated (and optionally reduced) log-probability.

        ``a`` is handed to ``Weibull`` as its first argument and to
        ``weibull_log_probs`` as ``alpha``; ``b`` is the second ``Weibull``
        argument and ``beta`` respectively.
        """
        # Log-probabilities are computed before the reduction is inspected.
        log_likelihood = (
            weibull_log_probs(b, a, target, self.eps)
            if self.mode == "explicit"
            else Weibull(a, b).log_prob(target)
        )

        how = self.reduction
        if how == "none":
            return -log_likelihood
        if how == "sum":
            return -log_likelihood.sum()
        if how == "mean":
            return -log_likelihood.mean()
        raise NotImplementedError(f"Unavailable reduction type: {self.reduction}")