cyto_dl.models.classification.classification module
- class cyto_dl.models.classification.classification.Classification(*args, **kwargs)
Bases: BaseModel
- Parameters:
model (nn.Module) – Model network; its parameters are shared between task heads
x_key (str) – Key of the input image in the batch
save_dir="./" – Directory in which to save images during training and validation
save_images_every_n_epochs=1 – Frequency, in epochs, at which to save images during training
compile=False – Whether to compile the model using torch.compile
**base_kwargs – Additional arguments passed to BaseModel
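To make the parameter semantics above concrete, here is a minimal, framework-free sketch of how such options could be consumed. It is not the cyto_dl implementation: `ClassificationSketch`, `forward_batch`, `should_save_images`, and `image_path` are hypothetical names. The sketch assumes that images are saved on epochs divisible by `save_images_every_n_epochs` and that file names follow an invented `<stage>_epoch<N>.npy` pattern. The only real library call is `torch.compile`, which runs only when `compile=True`.

```python
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class ClassificationSketch:
    """Hypothetical stand-in illustrating the documented parameters.

    NOT the cyto_dl Classification class; names and logic are assumptions.
    """

    model: Callable[[Any], Any]  # stands in for the shared nn.Module
    x_key: str  # key of the input image in each batch dict
    save_dir: str = "./"
    save_images_every_n_epochs: int = 1
    compile: bool = False

    def __post_init__(self) -> None:
        if self.save_images_every_n_epochs < 1:
            raise ValueError("save_images_every_n_epochs must be >= 1")
        if self.compile:
            import torch  # only imported when compilation is requested

            self.model = torch.compile(self.model)

    def forward_batch(self, batch: Dict[str, Any]) -> Any:
        # Look up the input image under the configured key and run the model.
        return self.model(batch[self.x_key])

    def should_save_images(self, epoch: int) -> bool:
        # Assumption: save on epochs divisible by the configured frequency.
        return epoch % self.save_images_every_n_epochs == 0

    def image_path(self, epoch: int, stage: str) -> str:
        # Hypothetical naming scheme for images written under save_dir.
        return os.path.join(self.save_dir, f"{stage}_epoch{epoch}.npy")


if __name__ == "__main__":
    clf = ClassificationSketch(model=lambda x: sum(x), x_key="raw",
                               save_images_every_n_epochs=2)
    print(clf.forward_batch({"raw": [1, 2, 3], "label": 0}))  # 6
    print([e for e in range(5) if clf.should_save_images(e)])  # [0, 2, 4]
```

In the real class, `model` would be an `nn.Module` and saving would presumably be driven by the training loop's epoch hooks. The sketch only shows how `x_key` selects the input from a batch dictionary and how a per-epoch save frequency gates image output.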