Source code for cyto_dl.nn.losses.adversarial_loss

from torch.nn.modules.loss import _Loss as Loss


class AdversarialLoss(Loss):
    """Loss computed on a discriminator's prediction for the given input.

    ``forward`` runs ``discriminator(input)``, scores that prediction against
    ``target`` with ``loss`` and then applies ``reduction`` to the result.

    Args:
        discriminator: Callable applied to ``input`` in ``forward``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        loss: Callable invoked as ``loss(prediction, target)``.
        argmax: If true (and ``squeeze`` is false), ``target.argmax(1)`` is
            passed to ``loss`` instead of ``target``.
        reduction: One of ``"none"``, ``"sum"`` or ``"mean"``. The value is
            only checked when ``forward`` runs.
        squeeze: If true, ``target.squeeze()`` is passed to ``loss`` instead
            of ``target``. Takes precedence over ``argmax``.
    """

    def __init__(
        self,
        discriminator,
        loss,
        argmax: bool = False,
        reduction: str = "mean",
        squeeze: bool = False,
    ):
        # The two ``None`` values fill the legacy ``size_average`` / ``reduce``
        # slots of ``_Loss.__init__``; only ``reduction`` is actually used.
        super().__init__(None, None, reduction)
        self.discriminator = discriminator
        self.loss = loss
        self.argmax = argmax
        self.squeeze = squeeze

    def forward(self, input, target, return_pred: bool = False):
        """Compute the loss of the discriminator's prediction on ``input``.

        Args:
            input: Passed unchanged to ``self.discriminator``.
            target: Reference value for ``self.loss``; transformed first when
                ``squeeze`` or ``argmax`` is enabled.
            return_pred: If true, also return the discriminator output.

        Returns:
            The reduced loss, or ``(loss, prediction)`` when ``return_pred``
            is true.

        Raises:
            NotImplementedError: If ``self.reduction`` is not ``"none"``,
                ``"sum"`` or ``"mean"``.
        """
        yhat = self.discriminator(input)

        # ``squeeze`` wins when both flags are set.
        if self.squeeze:
            # NOTE(review): ``squeeze()`` without a ``dim`` drops *every*
            # size-1 axis, which would include a batch axis of size 1 --
            # confirm this is intended for single-sample batches.
            target = target.squeeze()
        elif self.argmax:
            target = target.argmax(1)
        loss = self.loss(yhat, target)

        # ``sum()`` / ``mean()`` without a ``dim`` reduce over all elements of
        # whatever ``self.loss`` returned, not just a batch axis.
        if self.reduction == "none":
            pass
        elif self.reduction == "sum":
            loss = loss.sum()
        elif self.reduction == "mean":
            loss = loss.mean()
        else:
            raise NotImplementedError(f"Unavailable reduction type: {self.reduction}")

        if return_pred:
            return loss, yhat
        return loss