cyto_dl.models.im2im.gan module

class cyto_dl.models.im2im.gan.GAN(*args, **kwargs)

Bases: MultiTaskIm2Im

Basic GAN model.

Parameters:
  • backbone (nn.Module) – backbone network whose parameters are shared by all task heads

  • task_heads (Dict) – task-specific heads

  • discriminator – discriminator network

  • x_key (str) – key of the input image in the batch

  • save_dir (str = "./") – directory in which to save images during training and validation

  • save_images_every_n_epochs (int = 1) – how often, in epochs, to save images during training

  • inference_args (Dict = {}) – arguments passed to MONAI's [sliding window inferer](https://docs.monai.io/en/stable/inferers.html#sliding-window-inference)

  • compile (bool = False) – whether to compile the model with torch.compile

  • **base_kwargs – additional keyword arguments passed to BaseModel
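
The inference_args entry forwards settings such as roi_size, sw_batch_size and overlap to MONAI's sliding window inferer. As a rough illustration of how roi_size and overlap determine how an image is tiled at inference time, here is a simplified sketch of window placement along each axis. The helpers window_starts and num_windows are hypothetical and belong to neither cyto_dl nor MONAI; MONAI's own implementation also handles padding, batching of windows and blending of overlapping predictions, which this sketch omits.

```python
import math
from typing import List, Sequence


def window_starts(image_size: int, roi_size: int, overlap: float) -> List[int]:
    """Start offsets of 1-D sliding windows that together cover [0, image_size).

    Simplified illustration (not MONAI's code): the step between windows is
    roi_size * (1 - overlap), at least 1, and the final window is shifted back
    so that it ends exactly at the image edge.
    """
    if roi_size >= image_size:
        # One window covers the whole axis (a real inferer would typically pad).
        return [0]
    step = max(int(roi_size * (1 - overlap)), 1)
    num = math.ceil((image_size - roi_size) / step) + 1
    starts = [i * step for i in range(num)]
    # Clamp the last window so it does not run past the image boundary.
    starts[-1] = image_size - roi_size
    return starts


def num_windows(shape: Sequence[int], roi_size: Sequence[int], overlap: float) -> int:
    """Total number of windows for an N-D image: product of per-axis counts."""
    count = 1
    for size, roi in zip(shape, roi_size):
        count *= len(window_starts(size, roi, overlap))
    return count


# Illustrative settings using MONAI sliding-window parameter names.
inference_args = {"roi_size": (32, 32), "sw_batch_size": 4, "overlap": 0.25}
print(window_starts(100, 32, 0.25))  # [0, 24, 48, 68]
print(num_windows((100, 100), inference_args["roi_size"], inference_args["overlap"]))  # 16
```

Raising overlap shrinks the step between windows, so the number of windows (and inference cost) grows, in exchange for more overlapping predictions to average near tile borders.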

configure_optimizers()
model_step(stage, batch, batch_idx)
predict_step(batch, batch_idx)
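
model_step is listed without a description. For orientation, the sketch below computes the textbook binary-cross-entropy GAN objectives that a GAN training step typically evaluates: the discriminator is trained to score real images as 1 and generated images as 0, while the generator, in the common non-saturating form, is trained to make the discriminator score its outputs as 1. This is an assumption for illustration only. The page does not state which adversarial loss this class uses or how it combines it with the task-head losses, and the functions bce and gan_losses are hypothetical. The sketch works on plain lists of discriminator probabilities so it runs without PyTorch.

```python
import math
from typing import Sequence, Tuple


def bce(probs: Sequence[float], target: float, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy of predicted probabilities against one constant label."""
    total = 0.0
    for p in probs:
        p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
        total += -(target * math.log(p) + (1 - target) * math.log(1 - p))
    return total / len(probs)


def gan_losses(d_real: Sequence[float], d_fake: Sequence[float]) -> Tuple[float, float]:
    """Return (discriminator_loss, generator_loss).

    d_real / d_fake are the discriminator's probabilities that each real /
    generated image is real.
    """
    # Discriminator: label real images 1 and generated images 0.
    d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)
    # Generator (non-saturating form): push the discriminator to label fakes 1.
    g_loss = bce(d_fake, 1.0)
    return d_loss, g_loss


d_loss, g_loss = gan_losses(d_real=[0.9, 0.8], d_fake=[0.2, 0.3])
print(f"discriminator loss: {d_loss:.3f}, generator loss: {g_loss:.3f}")
```

Because the two losses pull in opposite directions, GAN implementations usually update the discriminator and the generator in separate optimizer steps. How configure_optimizers arranges this for this class is not documented here.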