cyto_dl.models.basic_model module

class cyto_dl.models.basic_model.BasicModel(*args, **kwargs)

Bases: BaseModel

A minimal PyTorch Lightning wrapper around generic PyTorch models.

Parameters:
  • network (Optional[nn.Module] = None) – The network to wrap. The network is assumed to output both the ground truth (gt) and the predictions.

  • loss (Optional[Loss] = None) – The loss function to optimize.

  • x_label (str = “x”) – The key used to retrieve the input from dataloader batches.

  • optimizer (torch.optim.Optimizer = torch.optim.Adam) – The optimizer class.

  • save_predictions (Optional[Callable] = None) – A function to save the results of serotiny predict.

  • fields_to_log (Optional[Union[Sequence, Dict]] = None) – Batch fields to store with the outputs. Use a dict to choose the fields separately for each stage (train, val, test, prediction). If a list is used, it is assumed to apply to the test and prediction stages only.

  • pretrained_weights (Optional[str] = None) – Path to pretrained weights. If network is not None, the weights are loaded into it via network.load_state_dict; otherwise, they are loaded via torch.load.
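The list-versus-dict behavior of fields_to_log can be pictured with the framework-free sketch below. It is not cyto_dl code: normalize_fields_to_log is a hypothetical helper, the stage keys "train", "val", "test" and "predict" are assumed names, and the dict form (one field list per stage) is an assumption inferred from the parameter's type. Only the list rule (test and prediction only) comes from the parameter description.

```python
from typing import Dict, List, Optional, Sequence, Union

# Assumed stage keys; the parameter docs list train, val, test and prediction.
STAGES = ("train", "val", "test", "predict")


def normalize_fields_to_log(
    fields: Optional[Union[Sequence[str], Dict[str, Sequence[str]]]],
) -> Dict[str, List[str]]:
    """Hypothetical helper: expand fields_to_log into one field list per stage."""
    if fields is None:
        # Nothing extra is logged for any stage.
        return {stage: [] for stage in STAGES}
    if isinstance(fields, dict):
        # Assumption: a dict maps each stage name to its own list of fields.
        return {stage: list(fields.get(stage, [])) for stage in STAGES}
    if isinstance(fields, str):
        # Treat a single field name as a one-element list.
        fields = [fields]
    # Documented behavior: a list applies to test and prediction only.
    return {
        stage: (list(fields) if stage in ("test", "predict") else [])
        for stage in STAGES
    }
```

Under these assumptions, passing a list such as ["cell_id"] (a hypothetical field name) would log that field during test and prediction but not during training or validation.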
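The pretrained_weights rule described above can be sketched as follows. This is a hypothetical helper, not the cyto_dl implementation: resolve_network is an illustrative name, the load argument stands in for torch.load so the sketch runs without PyTorch, and treating the loaded object as the model itself when network is None is an assumption about how the torch.load result is used.

```python
from typing import Any, Callable, Optional


def resolve_network(
    network: Optional[Any],
    pretrained_weights: Optional[str],
    load: Callable[[str], Any],
) -> Optional[Any]:
    """Hypothetical helper mirroring the documented pretrained_weights rules.

    `load` stands in for torch.load so the sketch runs without PyTorch.
    """
    if pretrained_weights is None:
        # No checkpoint given: use the network exactly as passed in.
        return network
    loaded = load(pretrained_weights)
    if network is not None:
        # Documented path: the weights are applied with network.load_state_dict.
        network.load_state_dict(loaded)
        return network
    # Documented path: the file is read with torch.load. Assumption: the
    # loaded object is itself the model and is used in place of a network.
    return loaded
```

Injecting the loader keeps the branching logic testable on its own; in real code the loader would be torch.load and network would be an nn.Module.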

forward(x, **kwargs)
model_step(stage, batch, batch_idx)
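A rough, framework-free picture of how forward and model_step might fit together is sketched below. MinimalWrapper is hypothetical and is not the real BasicModel, which builds on PyTorch Lightning. The sketch assumes that forward delegates to the wrapped network, that the network returns (gt, predictions) as the network parameter describes, that the loss takes predictions before targets, and that model_step returns (loss, predictions, gt).

```python
from typing import Any, Callable, Dict, Tuple


class MinimalWrapper:
    """Hypothetical, framework-free stand-in for the BasicModel step logic."""

    def __init__(
        self,
        network: Callable[..., Tuple[Any, Any]],
        loss: Callable[[Any, Any], Any],
        x_label: str = "x",
    ) -> None:
        self.network = network
        self.loss = loss
        self.x_label = x_label

    def forward(self, x: Any, **kwargs: Any) -> Tuple[Any, Any]:
        # Assumption: forward simply delegates to the wrapped network.
        return self.network(x, **kwargs)

    def model_step(self, stage: str, batch: Dict[str, Any], batch_idx: int):
        # The input is read from the batch under the x_label key.
        x = batch[self.x_label]
        # Assumption: the network returns (ground truth, predictions).
        gt, preds = self.forward(x)
        # Assumed argument order for the loss: predictions first, then targets.
        loss = self.loss(preds, gt)
        # Assumed return convention; stage and batch_idx are unused here.
        return loss, preds, gt
```

Keeping forward a thin pass-through means any callable network can be wrapped unchanged, while model_step holds the batch-handling and loss logic in one place.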