cyto_dl.nn.head.mask_head module

class cyto_dl.nn.head.mask_head.MaskHead(loss, mask_key: str = 'mask', postprocess={'input': <function detach>, 'prediction': <function detach>})

Bases: BaseHead

Task Head using a masked loss function.
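A masked loss counts only the elements that a mask selects, so unlabeled or ignored regions do not contribute to the loss. The sketch below shows one such loss, a masked mean squared error in plain Python. The name `masked_mse` and the choice of MSE are illustrative assumptions, not the loss that `MaskHead` uses; the head applies whatever `loss` callable it is given.

```python
from typing import Sequence


def masked_mse(prediction: Sequence[float], target: Sequence[float], mask: Sequence[float]) -> float:
    """Hypothetical masked MSE: mean squared error over elements where mask != 0."""
    if not (len(prediction) == len(target) == len(mask)):
        raise ValueError("prediction, target, and mask must have the same length")
    # Weight each squared error by its mask value (a 0 excludes the element).
    weighted = [m * (p - t) ** 2 for p, t, m in zip(prediction, target, mask)]
    total_weight = sum(mask)
    if total_weight == 0:
        # The mask selects nothing, so there is nothing to penalize.
        return 0.0
    return sum(weighted) / total_weight


# Only the first two elements count; the large error in the third is masked out.
print(masked_mse([1.0, 2.0, 10.0], [1.0, 4.0, 0.0], [1, 1, 0]))  # 2.0
```

Normalizing by the mask sum rather than the element count keeps the loss scale independent of how much of the input is masked out.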

Parameters:
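  • loss – Loss function for the task.

  • mask_key='mask' – Key in the batch dictionary under which the mask for the masked loss is stored.

  • postprocess={'input': detach, 'prediction': detach} – Postprocessing functions applied to the head's input and predictions.

  • calculate_metric=False – Whether to calculate a metric during training. Not used by the GAN head.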
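The following sketch illustrates how `mask_key` and `postprocess` could fit together. The mask is read from the batch dictionary under `mask_key` and passed to the loss with the prediction and target. Each postprocessing function is applied to its matching entry. The class `MaskHeadSketch`, the `target_key` argument, and the `(prediction, target, mask)` loss signature are hypothetical stand-ins, not cyto_dl's API. The function `identity` replaces the library's `detach`, which operates on tensors.

```python
from typing import Any, Callable, Dict, Optional, Sequence

LossFn = Callable[[Sequence[float], Sequence[float], Sequence[float]], float]


def identity(x: Any) -> Any:
    # Stand-in for `detach`; plain Python values need no detaching.
    return x


class MaskHeadSketch:
    """Hypothetical, simplified illustration of a head with a masked loss."""

    def __init__(
        self,
        loss: LossFn,
        target_key: str,
        mask_key: str = "mask",
        postprocess: Optional[Dict[str, Callable[[Any], Any]]] = None,
    ):
        self.loss = loss
        self.target_key = target_key  # hypothetical: where the target lives in the batch
        self.mask_key = mask_key
        # Avoid a shared mutable default; fall back to identity postprocessing.
        self.postprocess = postprocess or {"input": identity, "prediction": identity}

    def calculate_loss(self, prediction: Sequence[float], batch: Dict[str, Any]) -> float:
        # The mask is looked up in the same batch dictionary as the target.
        return self.loss(prediction, batch[self.target_key], batch[self.mask_key])

    def postprocess_outputs(self, inputs: Any, prediction: Any) -> Dict[str, Any]:
        return {
            "input": self.postprocess["input"](inputs),
            "prediction": self.postprocess["prediction"](prediction),
        }


def masked_abs_error(prediction, target, mask):
    # Sum of absolute errors over elements the mask keeps.
    return sum(m * abs(p - t) for p, t, m in zip(prediction, target, mask))


head = MaskHeadSketch(masked_abs_error, target_key="seg", mask_key="valid")
batch = {"raw": [0.0, 0.0, 0.0], "seg": [1.0, 1.0, 1.0], "valid": [1, 0, 1]}
print(head.calculate_loss([0.5, 0.0, 2.0], batch))  # 1.5
```

Keeping the mask in the batch under a configurable key lets the same head work with datasets that name their mask differently.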

run_head(backbone_features, batch, stage, n_postprocess, run_forward=True, y_hat=None)

Run the head on the backbone features, calculate the loss, postprocess and save images, and calculate metrics.
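A rough outline of the forward, loss, and postprocessing steps in this description follows. It assumes, based only on the parameter names, that `run_forward=False` means a precomputed `y_hat` is used instead of running the head. The function `run_head_sketch` and its callable arguments are hypothetical. Image saving, metrics, `stage`, and `n_postprocess` are omitted because their behavior is not described here.

```python
from typing import Any, Callable, Dict, Optional


def run_head_sketch(
    forward: Callable[[Any], Any],
    loss_fn: Callable[[Any, Dict[str, Any]], float],
    postprocess: Callable[[Any], Any],
    backbone_features: Any,
    batch: Dict[str, Any],
    run_forward: bool = True,
    y_hat: Optional[Any] = None,
) -> Dict[str, Any]:
    """Hypothetical outline of the forward -> loss -> postprocess sequence."""
    if run_forward:
        # Compute the head's prediction from the backbone features.
        y_hat = forward(backbone_features)
    elif y_hat is None:
        # Assumption: skipping the forward pass requires a precomputed prediction.
        raise ValueError("y_hat must be provided when run_forward is False")
    loss = loss_fn(y_hat, batch)
    # Postprocess the prediction (e.g. detaching) before it is saved or logged.
    return {"loss": loss, "y_hat": postprocess(y_hat)}


out = run_head_sketch(
    forward=lambda feats: [2 * x for x in feats],  # toy head: doubles each feature
    loss_fn=lambda y, b: sum(abs(p - t) for p, t in zip(y, b["target"])),
    postprocess=lambda y: y,
    backbone_features=[1.0, 2.0],
    batch={"target": [2.0, 3.0]},
)
print(out)  # {'loss': 1.0, 'y_hat': [2.0, 4.0]}
```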