Source code for cyto_dl.nn.losses.continuous_bernoulli
from torch.distributions.continuous_bernoulli import ContinuousBernoulli
from torch.nn.modules.loss import _Loss as Loss
[docs]class CBLogLoss(Loss):
"""Continuous Bernoulli loss, proposed here:
https://arxiv.org/abs/1907.06845.
"""
def __init__(self, reduction="mean", mode="probs"):
super().__init__(None, None, reduction)
self.mode = mode
[docs] def forward(self, input, target):
# the trick with the dictionary allows us to use either `probs` or `logits`
log_probs = ContinuousBernoulli(**{self.mode: input}).log_prob(target)
# sum per input-element log loss
log_probs = log_probs.view(log_probs.shape[0], -1).sum(axis=1)
# reduce across batch dimension
if self.reduction == "none":
return -log_probs
elif self.reduction == "sum":
return -log_probs.sum()
elif self.reduction == "mean":
return -log_probs.mean()
else:
raise NotImplementedError(f"Unavailable reduction type: {self.reduction}")