cyto_dl.models.base_model module

class cyto_dl.models.base_model.BaseModel(*args, **kwargs)

Bases: LightningModule

compute_metrics(loss, preds, targets, split)

Computes and logs metrics. It makes the following assumptions:

  • the _step method of each model returns a tuple (loss, preds, targets), whose elements may be dictionaries

  • the keys of self.metrics follow a specific structure, ‘split/type(/part)’, where:

    • split is one of (train, val, test, predict)

    • type is either “loss” or an arbitrary string naming a metric

    • part is optional and is used when (loss, preds, targets) are dictionaries, in which case it must match one of their keys
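The key convention above can be sketched as a small parser plus a helper that picks the matching entry when the step outputs are dictionaries. This is an illustration only: MetricKey, parse_metric_key and select_part are hypothetical names, not part of cyto_dl, and the real compute_metrics may parse keys differently.

```python
from typing import Any, NamedTuple, Optional

# Splits listed in the compute_metrics docstring
VALID_SPLITS = ("train", "val", "test", "predict")


class MetricKey(NamedTuple):
    """Components of a 'split/type(/part)' key (hypothetical helper type)."""
    split: str
    metric_type: str
    part: Optional[str]


def parse_metric_key(key: str) -> MetricKey:
    """Split a 'split/type' or 'split/type/part' key into its components."""
    pieces = key.split("/")
    if len(pieces) not in (2, 3):
        raise ValueError(f"expected 'split/type' or 'split/type/part', got {key!r}")
    split, metric_type = pieces[0], pieces[1]
    if split not in VALID_SPLITS:
        raise ValueError(f"unknown split {split!r}")
    part = pieces[2] if len(pieces) == 3 else None
    return MetricKey(split, metric_type, part)


def select_part(value: Any, part: Optional[str]) -> Any:
    """Return value unchanged when part is None, else value[part] for dict outputs."""
    if part is None:
        return value
    if not isinstance(value, dict) or part not in value:
        raise KeyError(f"part {part!r} must be a key of the step outputs")
    return value[part]
```

For example, "val/loss" parses to split "val", type "loss" and no part, while "train/mse/seg" would select the "seg" entry of dictionary-valued preds and targets.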

configure_optimizers()
forward(x, **kwargs)
model_step(stage, batch, batch_idx)

Implement here the logic for a single step of the training, validation, or test loop. The stage currently running (training, validation, or test) is given by the stage argument.

Example:

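    x = self.parse_batch(batch)
    # ... model-specific logic: compute outputs from x and collect them in a dict ...
    results = {}  # placeholder for the outputs computed by the elided logic

    # Optionally attach sample IDs so outputs can be traced back to their inputs
    if self.hparams.id_label is not None:
        if self.hparams.id_label in batch:
            ids = batch[self.hparams.id_label].detach().cpu()
            results.update({
                self.hparams.id_label: ids,
                "id": ids,
            })

    return results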
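A framework-free sketch of the pattern described for model_step: each of training_step, validation_step and test_step delegates to one shared model_step, passing the stage as a string. Everything here is hypothetical: ToyStepModel, the "raw"/"target" batch keys, the toy linear model and MSE loss are illustrations, and the stage strings "train", "val", "test" are assumed to match the split names used by compute_metrics. It returns a results dict as in the docstring example, whereas compute_metrics describes a (loss, preds, targets) tuple.

```python
from typing import Any, Dict, List, Optional


class ToyStepModel:
    """Hypothetical stand-in for a BaseModel subclass (no Lightning, no torch)."""

    def __init__(self, id_label: Optional[str] = None, weight: float = 2.0):
        self.id_label = id_label  # plays the role of self.hparams.id_label
        self.weight = weight      # a single toy "parameter"

    def parse_batch(self, batch: Dict[str, Any]) -> List[float]:
        # Assumption: inputs are stored under a hypothetical "raw" key
        return batch["raw"]

    def model_step(self, stage: str, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        x = self.parse_batch(batch)
        preds = [self.weight * v for v in x]
        targets = batch["target"]
        # Mean squared error as an illustrative loss
        loss = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)
        results = {"stage": stage, "loss": loss, "preds": preds, "targets": targets}

        # ID bookkeeping mirroring the docstring example
        if self.id_label is not None and self.id_label in batch:
            ids = batch[self.id_label]
            results.update({self.id_label: ids, "id": ids})
        return results

    # Each Lightning-style hook only chooses the stage and delegates
    def training_step(self, batch, batch_idx):
        return self.model_step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.model_step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.model_step("test", batch, batch_idx)
```

Keeping the loss and ID logic in one method means the three hooks cannot drift apart; only the stage label differs between them.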

on_train_start()
parse_batch(batch)
predict_step(batch, batch_idx)

Implement here the logic for an inference step.

In most cases this simply calls the model's forward pass, but you may want to add post-processing.
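A minimal sketch of an inference step that runs the forward pass and then post-processes the result, here by turning logits into probabilities and thresholded labels. ToyPredictor, its affine forward pass, the "raw" batch key and the 0.5 threshold are all hypothetical choices for illustration, not cyto_dl code.

```python
import math
from typing import Any, Dict, List


class ToyPredictor:
    """Hypothetical model: predict_step = forward pass + post-processing."""

    def __init__(self, weight: float = 1.0, bias: float = 0.0, threshold: float = 0.5):
        self.weight = weight
        self.bias = bias
        self.threshold = threshold

    def forward(self, x: List[float]) -> List[float]:
        # Toy forward pass: an affine map producing raw scores (logits)
        return [self.weight * v + self.bias for v in x]

    def predict_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, List]:
        logits = self.forward(batch["raw"])
        # Post-processing: logits -> probabilities (sigmoid) -> binary labels
        probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
        labels = [int(p >= self.threshold) for p in probs]
        return {"probs": probs, "labels": labels}
```

In a real LightningModule the forward pass is usually invoked as self(x) rather than self.forward(x), so that torch.nn.Module hooks run.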

test_step(batch, batch_idx)
training_step(batch, batch_idx)
validation_step(batch, batch_idx)
class cyto_dl.models.base_model.BaseModelMeta

Bases: type