Source code for cyto_dl.nn.losses.threshold_loss

import torch
from numpy.typing import ArrayLike
from torch import nn


class ThresholdLoss(nn.Module):
    """Wrapper loss that binarizes the target before delegating to ``loss_fn``.

    The target is compared element-wise against ``threshold`` and the
    resulting boolean mask is cast back to the target's own dtype, so the
    wrapped loss receives a tensor of zeros and ones.
    """

    def __init__(self, loss_fn, threshold: float = 0.0, above: bool = True):
        """Store the wrapped loss and the thresholding configuration.

        Parameters
        ----------
        loss_fn
            Callable invoked as ``loss_fn(input, thresholded_target)``.
        threshold: float = 0.0
            Value the target is compared against.
        above: bool = True
            If True, elements with ``target >= threshold`` become 1 and all
            others 0. If False, elements with ``target <= threshold`` become
            1 and all others 0. Both comparisons are inclusive.
        """
        super().__init__()
        self.loss_fn = loss_fn
        self.threshold = threshold
        self.above = above

    # Implemented as ``forward`` rather than overriding ``__call__`` so that
    # ``nn.Module.__call__`` keeps running its hook machinery; calling the
    # instance as ``loss(input, target)`` works exactly as before.
    def forward(self, input, target):
        """Binarize ``target`` and return ``loss_fn(input, binarized_target)``.

        Parameters
        ----------
        input
            Passed through to ``loss_fn`` unchanged.
        target
            Tensor to threshold; the binarized result keeps its dtype.

        Returns
        -------
        Whatever ``loss_fn`` returns.
        """
        if self.above:
            mask = target >= self.threshold
        else:
            mask = target <= self.threshold
        # The comparison yields a bool tensor; cast back to the target's dtype.
        return self.loss_fn(input, mask.type_as(target))