cyto_dl.models.base_model module
- class cyto_dl.models.base_model.BaseModel(*args, **kwargs)
Bases: LightningModule
- compute_metrics(loss, preds, targets, split)
Handles metric logging. It makes the following assumptions:
- the _step method of each model returns a tuple (loss, preds, targets), whose elements may be dictionaries;
- the keys of self.metrics follow the structure ‘split/type(/part)’, where:
  - split is one of train, val, test, or predict;
  - type is either “loss” or an arbitrary string naming a metric;
  - part is optional; it is used when (loss, preds, targets) are dictionaries, in which case it must match one of their keys.
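The key convention above can be illustrated with a small, framework-free sketch. The names parse_metric_key, select_part, MetricKey and VALID_SPLITS are hypothetical and are not part of cyto_dl; the sketch assumes "/" is the only separator in a key and only shows how a key such as "val/loss/seg" could be split and matched against dictionary-valued step outputs, not how compute_metrics is actually implemented.

```python
from typing import Any, NamedTuple, Optional

# Hypothetical: the splits listed in the compute_metrics docstring.
VALID_SPLITS = ("train", "val", "test", "predict")


class MetricKey(NamedTuple):
    split: str
    metric_type: str
    part: Optional[str]


def parse_metric_key(key: str) -> MetricKey:
    """Split a 'split/type(/part)' key into its components (illustrative only)."""
    pieces = key.split("/")
    if len(pieces) not in (2, 3):
        raise ValueError(f"expected 'split/type' or 'split/type/part', got {key!r}")
    split, metric_type = pieces[0], pieces[1]
    if split not in VALID_SPLITS:
        raise ValueError(f"unknown split {split!r}")
    part = pieces[2] if len(pieces) == 3 else None
    return MetricKey(split, metric_type, part)


def select_part(value: Any, part: Optional[str]) -> Any:
    """Return the entry matching part when a step output is a dictionary."""
    if part is None:
        return value
    if not isinstance(value, dict):
        raise TypeError("a part was given, but the step output is not a dictionary")
    return value[part]  # raises KeyError if part does not match a dictionary key


# Example: pick the "seg" entry of a dictionary-valued loss for a validation key.
loss = {"seg": 0.4, "recon": 0.1}
key = parse_metric_key("val/loss/seg")
seg_loss = select_part(loss, key.part)
```

Keys without a part (for example "train/accuracy") leave the output untouched, which matches the case where (loss, preds, targets) are not dictionaries.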
- model_step(stage, batch, batch_idx)
Implement the logic for a single step of the training, validation, or test process here. The current stage is given by the stage argument.
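Example:
    x = self.parse_batch(batch)
    # placeholder: run the model on x and collect its outputs in results
    results = {}

    if self.hparams.id_label is not None:
        if self.hparams.id_label in batch:
            # detach the ids and move them to the CPU before returning them
            ids = batch[self.hparams.id_label].detach().cpu()
            results.update({self.hparams.id_label: ids, "id": ids})

    return results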
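The id-handling part of the example can be exercised without Lightning or PyTorch. The sketch below is hypothetical: attach_ids is not a cyto_dl function, plain lists stand in for tensors (so the .detach().cpu() call from the example is replaced by a list copy), the id_label argument plays the role of self.hparams.id_label, and "cell_id" is an illustrative label only.

```python
from typing import Any, Dict, Optional


def attach_ids(results: Dict[str, Any], batch: Dict[str, Any],
               id_label: Optional[str]) -> Dict[str, Any]:
    """Copy sample ids from the batch into the step results (hypothetical helper)."""
    # Equivalent to the two nested checks in the example above.
    if id_label is not None and id_label in batch:
        # With tensors this would be batch[id_label].detach().cpu();
        # a list copy stands in for that here.
        ids = list(batch[id_label])
        # Store the ids under both the configured label and the generic "id" key.
        results.update({id_label: ids, "id": ids})
    return results


# Usage: a toy batch with one sample and an illustrative "cell_id" label.
batch = {"raw": [[0.0, 1.0]], "cell_id": ["cell_001"]}
results = attach_ids({"loss": 0.25}, batch, "cell_id")
```

If id_label is None or missing from the batch, the results are returned unchanged, mirroring the guard conditions in the example.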