cyto_dl.nn.head.mask_head module
- class cyto_dl.nn.head.mask_head.MaskHead(loss, mask_key: str = 'mask', postprocess={'input': <function detach>, 'prediction': <function detach>})
Bases: BaseHead
Task head that computes its loss with a masked loss function.
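As a rough illustration of what "masked loss" means, the sketch below computes a mean squared error that ignores positions where a binary mask is 0 (1 = keep, 0 = ignore) and averages only over the kept positions. `masked_mse` is a hypothetical, framework-free helper, not part of cyto_dl; MaskHead receives its actual loss function through the `loss` argument and presumably operates on tensors rather than Python lists.

```python
from typing import Sequence


def masked_mse(prediction: Sequence[float], target: Sequence[float], mask: Sequence[float]) -> float:
    """Hypothetical masked MSE: squared errors at masked-out positions are ignored."""
    if not (len(prediction) == len(target) == len(mask)):
        raise ValueError("prediction, target and mask must have the same length")
    # Multiply each squared error by its mask value so masked-out positions contribute 0.
    weighted = [m * (p - t) ** 2 for p, t, m in zip(prediction, target, mask)]
    total_weight = sum(mask)
    if total_weight == 0:
        # Nothing is unmasked, so there is no error to average.
        return 0.0
    # Normalise by the number of kept positions, not by the full length.
    return sum(weighted) / total_weight


# The third position differs most from its target but is masked out.
loss_value = masked_mse([1.0, 2.0, 3.0], [1.0, 0.0, 0.0], [1, 1, 0])  # (0 + 4) / 2 = 2.0
```

Dividing by the mask sum rather than the sequence length keeps the loss on the same scale regardless of how many positions are masked.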
- Parameters:
loss – Loss function for the task.
postprocess={"input": detach, "prediction": detach} – Postprocessing functions applied to the head's input and predictions.
calculate_metric=False – Whether to calculate a metric during training. Not used by the GAN head.
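The `postprocess` argument maps the keys `"input"` and `"prediction"` to callables. A minimal sketch of how such a mapping might be applied follows, assuming each callable takes a single value and returns the transformed value. `apply_postprocess` and `identity_detach` are hypothetical names, not cyto_dl functions; the stand-in detach simply returns its input, whereas a real one would detach a tensor from the autograd graph.

```python
from typing import Any, Callable, Dict, List


def identity_detach(x: Any) -> Any:
    """Hypothetical stand-in for a detach function: returns x unchanged."""
    return x


def apply_postprocess(
    outputs: Dict[str, Any], postprocess: Dict[str, Callable[[Any], Any]]
) -> Dict[str, Any]:
    # Apply the configured function to each output that has one; pass the rest through.
    return {
        key: postprocess[key](value) if key in postprocess else value
        for key, value in outputs.items()
    }


def threshold(values: List[float]) -> List[int]:
    # Example prediction postprocessing: binarise scores at 0.5.
    return [1 if v > 0.5 else 0 for v in values]


outputs = {"input": [0.2, 0.8], "prediction": [0.4, 0.6], "target": [0.0, 1.0]}
post = {"input": identity_detach, "prediction": threshold}
result = apply_postprocess(outputs, post)
# result == {"input": [0.2, 0.8], "prediction": [0, 1], "target": [0.0, 1.0]}
```

Keeping postprocessing as a key-to-callable mapping lets each output be transformed independently, and keys without an entry are left untouched.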