Source code for fnet.losses

"""Loss functions for fnet models."""


from typing import Optional

import torch


class HeteroscedasticLoss(torch.nn.Module):
    """Loss function to capture heteroscedastic aleatoric uncertainty.

    The model predicts, per element, a mean and a log-variance ``s``. The
    per-element loss is ``0.5 * exp(-s) * (mean - y)**2 + 0.5 * s``, which is
    the Gaussian negative log-likelihood up to an additive constant.
    """

    def forward(self, y_hat_batch: torch.Tensor, y_batch: torch.Tensor):
        """Calculates loss.

        Parameters
        ----------
        y_hat_batch
            Batched, 2-channel model output with channels on dim 1:
            channel 0 is the predicted mean, channel 1 the predicted
            log-variance. Any further channels are ignored. Any number of
            trailing (e.g. spatial) dimensions is accepted.
        y_batch
            Batched, 1-channel target output; it is combined element-wise
            (with broadcasting) with a single channel of ``y_hat_batch``.

        Returns
        -------
        torch.Tensor
            Scalar loss averaged over all elements.
        """
        # Slice only the channel axis so the loss is not tied to 5-D
        # (N, C, Z, Y, X) input; 4-D (N, C, Y, X) works too. `0:1` rather
        # than `0` keeps the channel dim so shapes line up with the
        # 1-channel target.
        mean_batch = y_hat_batch[:, 0:1]
        log_var_batch = y_hat_batch[:, 1:2]
        # exp(-s) down-weights the squared error where the model reports
        # high variance; the 0.5 * s term penalizes inflating the variance.
        loss_batch = (
            0.5 * torch.exp(-log_var_batch) * (mean_batch - y_batch).pow(2)
            + 0.5 * log_var_batch
        )
        return loss_batch.mean()
class WeightedMSE(torch.nn.Module):
    """Criterion for weighted mean-squared error."""

    def forward(
        self,
        y_hat_batch: torch.Tensor,
        y_batch: torch.Tensor,
        weight_map_batch: Optional[torch.Tensor] = None,
    ):
        """Calculates weighted MSE.

        Parameters
        ----------
        y_hat_batch
            Batched prediction (batch on dim 0).
        y_batch
            Batched target.
        weight_map_batch
            Optional weight map, multiplied element-wise (with broadcasting)
            into the squared error. It may have lower rank than the
            prediction (e.g. a spatial-only map); the reduction is always
            taken over the broadcast result.

        Returns
        -------
        torch.Tensor
            Scalar loss. Without a weight map this is the mean squared error
            over all elements. With a weight map, the weighted squared error
            is summed over every non-batch dimension and then averaged over
            the batch.

        Notes
        -----
        The two branches are scaled differently. For same-shaped prediction
        and target, an all-ones weight map gives the unweighted result
        multiplied by the number of elements per sample. Presumably weight
        maps are meant to be normalized per sample; confirm against callers.
        """
        if weight_map_batch is None:
            return torch.nn.functional.mse_loss(y_hat_batch, y_batch)
        weighted = weight_map_batch * (y_hat_batch - y_batch) ** 2
        # Derive the reduction axes from the broadcast product, not from the
        # weight map: a lower-rank (broadcast) weight map would otherwise
        # reduce over the wrong axes, and a 0-d map would give no axes.
        non_batch_dims = tuple(range(1, weighted.dim()))
        # NOTE: in at least some PyTorch versions `sum(dim=())` reduces over
        # all dims (collapsing the batch too), so 1-D input is used as-is:
        # it is already one value per sample.
        if non_batch_dims:
            per_sample = weighted.sum(dim=non_batch_dims)
        else:
            per_sample = weighted
        return per_sample.mean()