cyto_dl.nn.head.mae_head module

class cyto_dl.nn.head.mae_head.MAEHead(loss, postprocess={'input': <function detach>, 'prediction': <function detach>})

Bases: BaseHead

Parameters:
  • loss – Loss function for the task

  • postprocess (dict, default {"input": detach, "prediction": detach}) – Postprocessing functions applied to the head's input and predictions
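The `postprocess` argument maps output names ("input", "prediction") to callables. The following is a minimal, framework-free sketch of how such a mapping could be applied. It is not cyto_dl code: `detach` here is a hypothetical stand-in that copies a list (the real default is a function named `detach`, presumably operating on tensors), and `apply_postprocess` is a hypothetical helper.

```python
from typing import Any, Callable, Dict, List


def detach(x: List[float]) -> List[float]:
    """Hypothetical stand-in for a tensor detach: returns a plain copy."""
    return list(x)


# Mirrors the shape of MAEHead's default `postprocess` argument:
# one callable per output key.
DEFAULT_POSTPROCESS: Dict[str, Callable[[Any], Any]] = {
    "input": detach,
    "prediction": detach,
}


def apply_postprocess(
    outputs: Dict[str, Any],
    postprocess: Dict[str, Callable[[Any], Any]] = DEFAULT_POSTPROCESS,
) -> Dict[str, Any]:
    """Apply the matching postprocess function to each output; pass other keys through."""
    return {
        key: postprocess[key](value) if key in postprocess else value
        for key, value in outputs.items()
    }


# Usage: "loss" has no postprocess entry, so it is returned unchanged.
outputs = {"input": [0.1, 0.2], "prediction": [0.3, 0.4], "loss": 0.05}
processed = apply_postprocess(outputs)
```

Because a dict default in a Python signature is evaluated once, it is shared by every call (and every instance) that does not override it; the sketch only reads from it and never mutates it.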

run_head(backbone_features, batch, stage, n_postprocess=1, run_forward=True, y_hat=None)

Run the head on the backbone features, calculate the loss, postprocess and save images, and calculate metrics.
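The control flow described above can be outlined in plain Python. This is a hypothetical sketch, not cyto_dl's implementation: the `forward` callable, the batch key `"input"`, the reading of `n_postprocess` as "how many samples to postprocess", and the returned dict are all assumptions inferred from the parameter names. Metric calculation and image saving are omitted, and `stage` is simply passed through.

```python
from typing import Any, Callable, Dict, List, Optional

Sample = List[float]


def mse(pred: List[Sample], target: List[Sample]) -> float:
    """Mean squared error over all values in all samples."""
    diffs = [(p - t) ** 2 for ps, ts in zip(pred, target) for p, t in zip(ps, ts)]
    return sum(diffs) / len(diffs)


class SketchHead:
    """Hypothetical outline of a head's run_head control flow (not cyto_dl code)."""

    def __init__(
        self,
        loss: Callable[[List[Sample], List[Sample]], float],
        forward: Callable[[Any], List[Sample]],
        postprocess: Dict[str, Callable[[Sample], Sample]],
    ):
        self.loss = loss
        self.forward = forward  # assumed: maps backbone features to a prediction
        self.postprocess = postprocess

    def run_head(
        self,
        backbone_features: Any,
        batch: Dict[str, List[Sample]],
        stage: str,
        n_postprocess: int = 1,
        run_forward: bool = True,
        y_hat: Optional[List[Sample]] = None,
    ) -> Dict[str, Any]:
        # Either compute the prediction here or reuse one supplied by the caller.
        if run_forward:
            y_hat = self.forward(backbone_features)
        elif y_hat is None:
            raise ValueError("y_hat is required when run_forward=False")

        # A reconstruction task compares the prediction with the input (assumed key "input").
        target = batch["input"]
        loss = self.loss(y_hat, target)

        # Postprocess (e.g., detach before saving) at most n_postprocess samples.
        n = min(n_postprocess, len(y_hat))
        postprocessed = [
            {
                "input": self.postprocess["input"](target[i]),
                "prediction": self.postprocess["prediction"](y_hat[i]),
            }
            for i in range(n)
        ]
        return {"stage": stage, "loss": loss, "y_hat": y_hat, "postprocessed": postprocessed}


# Usage: an identity `forward` treats the backbone output as the reconstruction.
head = SketchHead(loss=mse, forward=lambda feats: feats,
                  postprocess={"input": list, "prediction": list})
batch = {"input": [[1.0, 2.0], [3.0, 4.0]]}
result = head.run_head([[1.0, 2.0], [3.0, 5.0]], batch, stage="val", n_postprocess=1)
```

Passing `run_forward=False` together with a precomputed `y_hat` skips the forward step, which is one plausible use of those two parameters.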