cyto_dl.models.im2im.multi_task module

class cyto_dl.models.im2im.multi_task.MultiTaskIm2Im(*args, **kwargs)

Bases: BaseModel

Parameters:
  • backbone (nn.Module) – Backbone network whose parameters are shared across all task heads.

  • task_heads (Dict) – Mapping from each task name to its task-specific head.

  • x_key (str) – Key of the input image in the batch.

  • save_dir (str = "./") – Directory in which images are saved during training and validation.

  • save_images_every_n_epochs (int = 1) – How often, in epochs, images are saved during training.

  • inference_args (Dict = {}) – Arguments passed to MONAI's [sliding window inferer](https://docs.monai.io/en/stable/inferers.html#sliding-window-inference).

  • inference_heads (Union[List, None] = None) – Optional list of heads to run during inference. Defaults to running all heads.

  • compile (bool = False) – Whether to compile the model with torch.compile.

  • **base_kwargs – Additional arguments passed to BaseModel.
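How `backbone`, `task_heads` and `inference_heads` fit together can be illustrated with a minimal, framework-free sketch. It assumes the backbone and heads are plain Python callables rather than `nn.Module`s, and `ToyMultiTask` is a hypothetical class, not part of cyto_dl. The backbone runs once per input, every selected head reuses its features, and `inference_heads=None` falls back to running all heads.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical stand-ins: in cyto_dl the backbone and heads are nn.Modules;
# plain callables on lists of floats keep the control flow easy to follow.
Backbone = Callable[[List[float]], List[float]]
Head = Callable[[List[float]], List[float]]


class ToyMultiTask:
    """Illustrative only: one shared backbone feeding several task heads."""

    def __init__(
        self,
        backbone: Backbone,
        task_heads: Dict[str, Head],
        inference_heads: Optional[List[str]] = None,
    ) -> None:
        self.backbone = backbone
        self.task_heads = task_heads
        # None means "run every head", mirroring the documented default.
        self.inference_heads = (
            list(task_heads) if inference_heads is None else list(inference_heads)
        )

    def forward(self, x: List[float], run_heads: List[str]) -> Dict[str, List[float]]:
        # The backbone is evaluated once; its features are reused by each head.
        features = self.backbone(x)
        return {name: self.task_heads[name](features) for name in run_heads}

    def predict(self, x: List[float]) -> Dict[str, List[float]]:
        # At inference time only the configured subset of heads is run.
        return self.forward(x, self.inference_heads)


if __name__ == "__main__":
    heads = {
        "seg": lambda f: [1.0 if v > 2 else 0.0 for v in f],
        "denoise": lambda f: [v / 2 for v in f],
    }
    model = ToyMultiTask(lambda x: [2 * v for v in x], heads, inference_heads=["seg"])
    print(model.predict([0.5, 1.5, 3.0]))
```

Sharing the backbone is the usual motivation for this design: each additional task adds only the cost of its head, and restricting `inference_heads` skips heads that are not needed at prediction time.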

configure_optimizers()
forward(x, run_heads)
get_n_postprocess_image(batch, batch_idx, stage)
get_per_head(outs, key)
model_step(stage, batch, batch_idx)
predict_step(batch, batch_idx)
run_forward(batch, stage, n_postprocess, run_heads)
test_step(batch, batch_idx)
training_step(batch, batch_idx)
validation_step(batch, batch_idx)
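The signatures above suggest that the Lightning hooks `training_step`, `validation_step` and `test_step` share a single `model_step(stage, batch, batch_idx)`. The sketch below illustrates that pattern under stated assumptions: `ToyStepRouter` is hypothetical, the stage strings `"train"`, `"val"` and `"test"` are guesses, and the image-saving rule (train and validation stages only, every `save_images_every_n_epochs` epochs) is one plausible reading of the documented parameters, not cyto_dl's actual logic.

```python
from typing import Any, Dict, List


class ToyStepRouter:
    """Hypothetical sketch: Lightning-style hooks funnelled into one model_step."""

    def __init__(self, save_images_every_n_epochs: int = 1) -> None:
        self.save_images_every_n_epochs = save_images_every_n_epochs
        self.current_epoch = 0  # Lightning tracks this on real LightningModules
        self.saved: List[str] = []  # records "stage@epoch" whenever images would be saved

    def _should_save(self, stage: str) -> bool:
        # Assumed schedule: save during train/val on every n-th epoch (epoch 0 included).
        return (
            stage in ("train", "val")
            and self.current_epoch % self.save_images_every_n_epochs == 0
        )

    def model_step(self, stage: str, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        # A real model would run the forward pass and per-head losses here;
        # this sketch only decides whether images would be written to disk.
        save = self._should_save(stage)
        if save:
            self.saved.append(f"{stage}@{self.current_epoch}")
        return {"stage": stage, "saved_images": save}

    # Each stage-specific hook only tags the stage and delegates.
    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        return self.model_step("train", batch, batch_idx)

    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        return self.model_step("val", batch, batch_idx)

    def test_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        return self.model_step("test", batch, batch_idx)
```

Routing every stage through one function keeps loss computation and logging in a single place, with the `stage` argument selecting any stage-specific behaviour.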