Source code for cyto_dl.nn.losses.loss_wrapper

import torch
from numpy.typing import ArrayLike
from torch import nn


class LossWrapper(nn.Module):
    """Loss wrapper for weighting the loss of each channel differently.

    `loss_fn` is calculated per-channel, scaled by `channel_weight`, averaged,
    and scaled by `loss_scale`.

    Parameters
    ----------
    loss_fn
        Loss function, called as ``loss_fn(y_hat_channel, y_channel)``.
    channel_weight: ArrayLike
        Array of floats with length equal to the number of channels of the
        predicted image.
    loss_scale: float
        Scale for the channel-weighted loss.
    """

    def __init__(self, loss_fn, channel_weight: ArrayLike, loss_scale: float = 1.0):
        super().__init__()
        self.loss_fn = loss_fn
        self.channel_weight = channel_weight
        self.loss_scale = loss_scale

    def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the channel-weighted loss between `y_hat` and `y`.

        Implemented as `forward` (rather than overriding `__call__`) so that
        `nn.Module.__call__` still dispatches registered hooks. Calling the
        instance directly, ``wrapper(y_hat, y)``, works as before.

        Parameters
        ----------
        y_hat: torch.Tensor
            Prediction, with channels on dim 1.
        y: torch.Tensor
            Target, with channels on dim 1.

        Returns
        -------
        torch.Tensor
            Mean of the weighted per-channel losses, multiplied by
            `loss_scale`.

        Raises
        ------
        AssertionError
            If the number of channel weights differs from ``y.shape[1]``.
            Only the target's channel count is checked.
        """
        assert (
            len(self.channel_weight) == y.shape[1]
        ), "Channel size mismatch, please adjust channel weights."

        # Slice with i : i + 1 (not plain i) so each per-channel tensor keeps
        # its channel axis and the loss function sees the same rank as the
        # full input.
        per_channel = [
            torch.mul(
                self.loss_fn(y_hat[:, i : i + 1], y[:, i : i + 1]),
                weight,
            )
            for i, weight in enumerate(self.channel_weight)
        ]
        loss = torch.stack(per_channel).mean()
        return loss * self.loss_scale
class CMAP_loss(nn.Module):
    """Loss wrapper for losses that accept a spatial costmap, differentially
    emphasizing pixel losses throughout an image.

    Parameters
    ----------
    loss
        Loss function. Should provide per-pixel losses. It must expose a
        ``.to(device)`` method, since it is moved to the prediction's device
        on every call.
    """

    def __init__(self, loss):
        super().__init__()
        self.loss = loss

    def forward(self, y_hat, y, cmap=None):
        """Return the mean (optionally costmap-weighted) loss.

        Implemented as `forward` (rather than overriding `__call__`) so that
        `nn.Module.__call__` still dispatches registered hooks. Calling the
        instance directly, ``criterion(y_hat, y, cmap)``, works as before.

        Parameters
        ----------
        y_hat
            Prediction tensor.
        y
            Target tensor. It is always cast to half precision before being
            passed to the loss.
            NOTE(review): presumably to match mixed-precision predictions --
            confirm against the training setup.
        cmap
            Optional costmap multiplied element-wise with the per-pixel loss.
            If ``None``, the unweighted mean loss is returned.

        Returns
        -------
        torch.Tensor
            Mean of the per-pixel loss, weighted by `cmap` when given.
        """
        # Keep the wrapped loss on the same device as the prediction.
        self.loss = self.loss.to(y_hat.device)
        pixel_loss = self.loss(y_hat, y.half())

        if cmap is None:
            return torch.mean(pixel_loss)

        # 2d head: a 4-d prediction paired with a 5-d costmap. Collapse the
        # costmap's dim 2 with a max so it matches the prediction's rank.
        if len(y_hat.shape) == 4 and len(cmap.shape) == 5:
            cmap, _ = torch.max(cmap, dim=2)

        return torch.mean(torch.mul(pixel_loss, cmap))