cyto_dl.models.basic_model module
- class cyto_dl.models.basic_model.BasicModel(*args, **kwargs)
Bases: BaseModel
A minimal PyTorch Lightning wrapper around generic PyTorch models.
- Parameters:
network (Optional[nn.Module] = None) – The network to wrap. Assumes that the network outputs both the ground truth (gt) and the predictions.
loss (Optional[Loss] = None) – The loss function to optimize.
x_label (str = “x”) – The key used to retrieve the input from dataloader batches.
optimizer (torch.optim.Optimizer = torch.optim.Adam) – The optimizer class.
save_predictions (Optional[Callable] = None) – A function used to save the results of serotiny predict.
fields_to_log (Optional[Union[Sequence, Dict]] = None) – Batch fields to store with the outputs, given as a list or a dict. Use a list to log the same fields for every stage (train, val, test, prediction). If a list is used, it is assumed to be for test and prediction only.
pretrained_weights (Optional[str] = None) – Path to pretrained weights. If network is not None, this will be loaded via network.load_state_dict, otherwise it will be loaded via torch.load.
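The parameters above fit together roughly as follows during one training step. This is a minimal plain-PyTorch sketch, not cyto_dl's implementation: the toy network ToyAutoencoder, the helper training_step, the learning rate, and the (gt, prediction) output order are assumptions made for illustration. It shows the input being pulled out of a batch dict by x_label, the network returning a ground-truth/prediction pair, the loss being computed on that pair, and the optimizer being passed as a class and only instantiated once the network's parameters exist.

```python
import torch
from torch import nn


class ToyAutoencoder(nn.Module):
    """Hypothetical toy network that returns (ground truth, prediction)."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.layer = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor):
        # For this toy autoencoder the "ground truth" is simply the input itself.
        return x, self.layer(x)


def training_step(network, loss_fn, optimizer, batch, x_label="x"):
    """One optimization step in the spirit of BasicModel (a sketch, not cyto_dl code)."""
    x = batch[x_label]            # retrieve the input by its batch key
    gt, pred = network(x)         # assumed (gt, prediction) output order
    loss = loss_fn(pred, gt)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()


torch.manual_seed(0)
network = ToyAutoencoder()
optimizer_cls = torch.optim.Adam  # passed as a class, like the optimizer parameter
optimizer = optimizer_cls(network.parameters(), lr=1e-2)
batch = {"x": torch.randn(8, 4)}
loss = training_step(network, nn.MSELoss(), optimizer, batch)
```

In BasicModel itself, PyTorch Lightning presumably drives the gradient zeroing, backpropagation, and optimizer stepping, so only the batch parsing, forward pass, and loss computation would live in the model code.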
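The two loading paths described for pretrained_weights can be sketched with a small helper. The function name load_pretrained_weights, the map_location="cpu" choice, and the assumption that the file holds a state dict whenever a network is given are illustrative, not taken from cyto_dl.

```python
import os
import tempfile
from typing import Optional

import torch
from torch import nn


def load_pretrained_weights(path: str, network: Optional[nn.Module] = None):
    """Hypothetical helper mirroring the documented pretrained_weights behavior."""
    if network is not None:
        # Assumed: the file holds a state dict matching the wrapped network.
        state_dict = torch.load(path, map_location="cpu")
        network.load_state_dict(state_dict)
        return network
    # No network given: return whatever object torch.load deserializes.
    return torch.load(path, map_location="cpu")


# Usage: save weights from one network and load them into a fresh copy.
source = nn.Linear(3, 2)
target = nn.Linear(3, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pt")
    torch.save(source.state_dict(), path)
    loaded = load_pretrained_weights(path, network=target)
    raw = load_pretrained_weights(path)
```

On PyTorch 2.6 and later, torch.load defaults to weights_only=True, so loading a fully pickled nn.Module through the second branch may require passing weights_only=False, which should only be done for trusted files.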