Source code for cyto_dl.nn.losses.weighted_mse_loss

import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss as Loss


class WeightedMSELoss(Loss):
    """Mean-squared-error loss with a multiplicative weight.

    The elementwise squared error ``(input - target) ** 2`` is multiplied by
    ``weights``. A leading singleton dimension is prepended to ``weights``, so
    it broadcasts over the batch dimension. The result is then optionally
    reduced over dimension 1.

    NOTE(review): ``weights`` presumably holds one weight per feature/channel
    along dimension 1 of a ``(batch, features)`` input. Confirm this against
    callers. For higher-rank inputs, the weights broadcast against the
    *trailing* dimensions instead.

    Args:
        reduction: ``"mean"`` or ``"sum"`` reduce over dimension 1 only; the
            batch dimension is kept. Any other value (default ``"none"``)
            returns the unreduced weighted loss.
        weights: Scalar, sequence, or tensor of weights. A copy is taken at
            construction time.
    """

    def __init__(self, reduction="none", weights=1):
        # _Loss.__init__ stores ``reduction`` on ``self``.
        super().__init__(None, None, reduction)
        weight_tensor = torch.as_tensor(weights).detach().clone().unsqueeze(0)
        # Store the weights as a buffer so that ``module.to(...)`` / ``.cuda()``
        # moves them with the module. Mark it non-persistent so ``state_dict``
        # is unchanged and previously saved checkpoints still load strictly.
        self.register_buffer("weights", weight_tensor, persistent=False)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Return the weighted squared error, reduced according to ``self.reduction``."""
        loss = F.mse_loss(input, target, reduction="none")
        # Match the loss's device and dtype. A module that was never moved
        # then still works on GPU inputs, and e.g. float64 or int64 weights do
        # not change the dtype of the returned loss.
        weights = self.weights.to(device=loss.device, dtype=loss.dtype)
        loss = loss * weights
        if self.reduction == "mean":
            loss = loss.mean(dim=1)
        elif self.reduction == "sum":
            loss = loss.sum(dim=1)
        return loss