"""Loss functions for fnet models."""
from typing import Optional
import torch
class HeteroscedasticLoss(torch.nn.Module):
    """Loss function to capture heteroscedastic aleatoric uncertainty.

    The model output is read as two channels along dimension 1: channel 0
    is used as the predicted mean and channel 1 as the predicted log
    variance.
    """

    def forward(
        self, y_hat_batch: torch.Tensor, y_batch: torch.Tensor
    ) -> torch.Tensor:
        """Calculates loss.

        Per element the loss is
        ``0.5 * exp(-log_var) * (mean - y) ** 2 + 0.5 * log_var``,
        which is then averaged over every element of the batch.

        Parameters
        ----------
        y_hat_batch
            Batched, 2-channel model output. Channels are indexed on
            dimension 1; any number of trailing dimensions is accepted.
        y_batch
            Batched, 1-channel target output. Must broadcast against a
            single-channel slice of ``y_hat_batch``.

        Returns
        -------
        torch.Tensor
            Scalar tensor holding the mean loss.

        """
        # Slice with ``0:1`` / ``1:2`` (not an integer index) to keep the
        # channel axis, and use an ellipsis so the number of trailing
        # dimensions is not hard-coded.
        mean_batch = y_hat_batch[:, 0:1, ...]
        log_var_batch = y_hat_batch[:, 1:2, ...]
        loss_batch = (
            0.5 * torch.exp(-log_var_batch) * (mean_batch - y_batch).pow(2)
            + 0.5 * log_var_batch
        )
        return loss_batch.mean()
class WeightedMSE(torch.nn.Module):
    """Criterion for weighted mean-squared error."""

    def forward(
        self,
        y_hat_batch: torch.Tensor,
        y_batch: torch.Tensor,
        weight_map_batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Calculates weighted MSE.

        Without a weight map this is the plain mean-squared error. With a
        weight map, the squared error is multiplied element-wise by the
        weights, summed over every dimension except dimension 0, and the
        resulting per-sample sums are averaged.

        Parameters
        ----------
        y_hat_batch
            Batched prediction.
        y_batch
            Batched target.
        weight_map_batch
            Optional weight map. Must broadcast against the squared error.

        Returns
        -------
        torch.Tensor
            Scalar tensor holding the loss.

        """
        if weight_map_batch is None:
            return torch.nn.functional.mse_loss(y_hat_batch, y_batch)
        weighted_sq_err = weight_map_batch * (y_hat_batch - y_batch) ** 2
        # Take the reduction axes from the broadcast result, not from the
        # weight map: a lower-rank weight map would otherwise leave some
        # non-batch axes unsummed and change the scale of the loss.
        # NOTE(review): assumes at least one non-batch dimension -- confirm.
        non_batch_dims = tuple(range(1, weighted_sq_err.dim()))
        return weighted_sq_err.sum(dim=non_batch_dims).mean()