import torch
from numpy.typing import ArrayLike
from torch import nn
[docs]class LossWrapper(nn.Module):
def __init__(self, loss_fn, channel_weight: ArrayLike, loss_scale: float = 1.0):
"""Loss Wrapper for weighting loss between channels differently. `loss_fn` is calculated
per-channel, scaled by `channel_weight`, averaged, and scaled by `loss_scale`
Parameters
----------
loss_fn
Loss function
channel_weight:ArrayLike
array of floats with length equal to number of channels of predicted image
loss_scale: float
Scale for channel-weighted loss.
"""
super().__init__()
self.loss_fn = loss_fn
self.channel_weight = channel_weight
self.loss_scale = loss_scale
def __call__(self, y_hat, y):
assert (
len(self.channel_weight) == y.shape[1]
), "Channel size mismatch, please adjust channel weights."
# calculate channel-wise loss, preserving NCZYX dims, and scale by channel weight
loss = torch.stack(
[
torch.mul(
self.loss_fn(y_hat[:, i : i + 1], y[:, i : i + 1]),
self.channel_weight[i],
)
for i in range(len(self.channel_weight))
]
).mean()
return loss * self.loss_scale
[docs]class CMAP_loss(nn.Module):
def __init__(self, loss):
"""Loss Wrapper for losses that accept a spatial costmap, differentially emphasizing pixel
losses throughout an image.
Parameters
----------
loss
Loss function. Should provide per-pixel losses.
"""
super().__init__()
self.loss = loss
def __call__(self, y_hat, y, cmap=None):
self.loss = self.loss.to(y_hat.device)
if cmap is None:
return torch.mean(self.loss(y_hat, y.half()))
# 2d head
if len(y_hat.shape) == 4 and len(cmap.shape) == 5:
cmap, _ = torch.max(cmap, dim=2)
return torch.mean(torch.mul(self.loss(y_hat, y.half()), cmap))