cyto_dl.models.classification.classification module

class cyto_dl.models.classification.classification.Classification(*args, **kwargs)

Bases: BaseModel

Parameters:
  • model (nn.Module) – model network; its parameters are shared between task heads

  • x_key (str) – key of the input image in the batch

  • save_dir (str, default "./") – directory in which to save images during training and validation

  • save_images_every_n_epochs (int, default 1) – how often, in epochs, to save images during training

  • compile (bool, default False) – whether to compile the model with torch.compile

  • **base_kwargs – additional keyword arguments passed to BaseModel
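A minimal sketch of how an epoch-based frequency such as save_images_every_n_epochs can gate image saving. The helper is_save_epoch is hypothetical and not part of cyto_dl; it assumes 0-indexed epochs and a plain modulo rule, whereas the class's own should_save_image(batch_idx, stage) may also take the batch index and stage into account.

```python
def is_save_epoch(current_epoch: int, save_images_every_n_epochs: int = 1) -> bool:
    """Return True if images should be saved at this (0-indexed) epoch.

    Hypothetical helper (not cyto_dl API): assumes saving happens whenever the
    epoch index is a multiple of ``save_images_every_n_epochs``.
    """
    if save_images_every_n_epochs < 1:
        raise ValueError("save_images_every_n_epochs must be >= 1")
    return current_epoch % save_images_every_n_epochs == 0


# With the default of 1 every epoch qualifies; with 5, only epochs 0, 5, 10, ...
print([e for e in range(12) if is_save_epoch(e, 5)])  # [0, 5, 10]
```

A larger value trades less disk usage and I/O for coarser visual monitoring of training.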

forward(x)

model_step(stage, batch, batch_idx)

predict_step(batch, batch_idx)

save_images(batch, stage, logits, name=None)

Create an image for each item in the batch with the predicted class and the true label overlaid as text.
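The text part of this step can be sketched as follows, assuming the predicted class is the argmax of each row of logits and that labels are integer class indices. The helper make_captions, the caption format, and the class names are hypothetical illustrations, not cyto_dl API; rendering the text onto the image itself is left out.

```python
from typing import List, Optional, Sequence


def make_captions(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    class_names: Optional[Sequence[str]] = None,
) -> List[str]:
    """Build one 'pred | label' caption per batch element (hypothetical helper)."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same batch size")
    captions = []
    for row, label in zip(logits, labels):
        # Predicted class = index of the largest logit (first one on ties).
        pred = max(range(len(row)), key=lambda i: row[i])
        if class_names is not None:
            pred_text, label_text = class_names[pred], class_names[label]
        else:
            pred_text, label_text = str(pred), str(label)
        captions.append(f"pred: {pred_text} | label: {label_text}")
    return captions


# Made-up logits for a batch of two images and three classes.
batch_logits = [[0.1, 2.3, -1.0], [1.5, 0.2, 0.9]]
batch_labels = [1, 2]
print(make_captions(batch_logits, batch_labels, ["G1", "S", "M"]))
# ['pred: S | label: S', 'pred: G1 | label: M']
```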

should_save_image(batch_idx, stage)