cyto_dl.nn.head.base_head module

class cyto_dl.nn.head.base_head.BaseHead(loss, postprocess={'input': <function detach>, 'prediction': <function detach>})

Bases: ABC, Module

Base class for task heads.

Parameters:
  • loss – Loss function for the task.

  • postprocess – Postprocessing functions applied to the head's input and predictions. Defaults to {"input": detach, "prediction": detach}.
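The constructor pairs a loss with a dictionary of postprocessing callables keyed by "input" and "prediction", and the class is abstract (it derives from ABC). The framework-free sketch below illustrates that pattern using only the standard library. It is not the cyto_dl implementation: SketchHead, DoubleHead, detach_stub, postprocess_pair and l1_loss are hypothetical names, plain lists stand in for tensors, and detach_stub is an identity function standing in for tensor detaching.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional


def detach_stub(x: Any) -> Any:
    """Hypothetical stand-in for tensor detaching; returns x unchanged."""
    return x


class SketchHead(ABC):
    """Framework-free analogue of a task head; not the cyto_dl class."""

    def __init__(
        self,
        loss: Callable[[Any, Any], float],
        postprocess: Optional[Dict[str, Callable[[Any], Any]]] = None,
    ) -> None:
        self.loss = loss
        # Mirror the documented default: one callable for the input and one
        # for the prediction. Using None avoids sharing a mutable default dict.
        self.postprocess = postprocess or {
            "input": detach_stub,
            "prediction": detach_stub,
        }

    @abstractmethod
    def forward(self, x: Any) -> Any:
        """Map backbone features to predictions (left to subclasses)."""

    def postprocess_pair(self, x: Any, y_hat: Any) -> tuple:
        # Apply the matching postprocessing callable to input and prediction.
        return self.postprocess["input"](x), self.postprocess["prediction"](y_hat)


class DoubleHead(SketchHead):
    """Toy head whose 'prediction' is each feature doubled."""

    def forward(self, x: Any) -> Any:
        return [2 * v for v in x]


def l1_loss(pred, target) -> float:
    # Sum of absolute differences, standing in for a real loss module.
    return sum(abs(p - t) for p, t in zip(pred, target))


head = DoubleHead(loss=l1_loss)
pred = head.forward([1, 2, 3])
print(pred, head.loss(pred, [2, 4, 5]))  # [2, 4, 6] 1
```

Because forward is abstract, SketchHead itself cannot be instantiated; only concrete subclasses such as DoubleHead can.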

forward(x)

generate_io_map(input_filenames)

Generates a map between input files and output files for a head.

Only used for prediction.
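As a rough illustration of an input-to-output file map, the sketch below maps each input filename to an output path for one head. The function name generate_io_map_sketch, the save_dir, head_name and suffix parameters, and the <save_dir>/<head_name>/<stem><suffix> layout are all hypothetical; the naming scheme cyto_dl actually uses is not documented here.

```python
from pathlib import Path
from typing import Dict, List


def generate_io_map_sketch(
    input_filenames: List[str],
    save_dir: str,
    head_name: str,
    suffix: str = ".tif",
) -> Dict[str, str]:
    """Map each input file to an output path for a single head.

    The layout <save_dir>/<head_name>/<input stem><suffix> is a hypothetical
    convention for illustration only.
    """
    out_dir = Path(save_dir) / head_name
    # as_posix() keeps the separators as "/" regardless of the platform.
    return {
        name: (out_dir / f"{Path(name).stem}{suffix}").as_posix()
        for name in input_filenames
    }


io_map = generate_io_map_sketch(["data/cell_01.tif", "data/cell_02.tif"], "preds", "seg")
print(io_map["data/cell_01.tif"])  # preds/seg/cell_01.tif
```

Keying the map by input filename lets a prediction loop look up where each head should write its output for a given input.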

run_head(backbone_features, batch, stage, n_postprocess=1, run_forward=True, y_hat=None)

Run the head on backbone features, calculate the loss, postprocess and save images, and calculate metrics.
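The control flow implied by the signature (optionally skipping the forward pass when a precomputed y_hat is supplied, then computing the loss and postprocessing) can be sketched as below. This is a hypothetical, framework-free outline, not the cyto_dl code. Several details are assumptions: the function name run_head_sketch, passing forward, loss_fn and postprocess as arguments, the "target" key in batch, skipping the loss when stage is "predict", and reading n_postprocess as the number of predictions to postprocess. Saving images and computing metrics are omitted.

```python
from typing import Any, Callable, Dict, List, Optional


def run_head_sketch(
    forward: Callable[[Any], List[Any]],
    loss_fn: Callable[[Any, Any], float],
    postprocess: Callable[[Any], Any],
    backbone_features: Any,
    batch: Dict[str, Any],
    stage: str,
    n_postprocess: int = 1,
    run_forward: bool = True,
    y_hat: Optional[List[Any]] = None,
) -> Dict[str, Any]:
    # Either compute the prediction or reuse one passed in by the caller.
    if run_forward:
        y_hat = forward(backbone_features)
    elif y_hat is None:
        raise ValueError("y_hat is required when run_forward=False")

    # Assumption: no target (and hence no loss) is available at prediction time.
    loss = None if stage == "predict" else loss_fn(y_hat, batch["target"])

    # Assumption: n_postprocess limits how many predictions are postprocessed
    # (for example, to bound how many images get saved).
    postprocessed = [postprocess(item) for item in y_hat[:n_postprocess]]

    # Saving images and computing metrics are omitted from this sketch.
    return {"loss": loss, "y_hat": y_hat, "postprocessed": postprocessed}


out = run_head_sketch(
    forward=lambda feats: [2 * f for f in feats],
    loss_fn=lambda p, t: sum(abs(a - b) for a, b in zip(p, t)),
    postprocess=lambda v: v / 2,
    backbone_features=[1, 2, 3],
    batch={"target": [2, 4, 5]},
    stage="train",
)
print(out)  # {'loss': 1, 'y_hat': [2, 4, 6], 'postprocessed': [1.0]}
```

Accepting a precomputed y_hat with run_forward=False lets a caller reuse one forward pass across several steps instead of recomputing it.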

update_params(params)