import copy
import inspect
import logging
from collections.abc import MutableMapping
from typing import Optional, Sequence, Union

import numpy as np
import torch
from hydra.utils import instantiate
from lightning import LightningModule
from omegaconf import DictConfig, ListConfig, OmegaConf
from torchmetrics import MeanMetric
# Type alias for array-like numeric inputs: a torch tensor, a NumPy array,
# or a plain sequence of floats.
Array = Union[torch.Tensor, np.ndarray, Sequence[float]]
# Logger used by this module; it reuses the logger named "lightning" rather
# than one named after this module.
logger = logging.getLogger("lightning")
# Do not pass records on to ancestor loggers' handlers (e.g. the root logger),
# presumably to avoid duplicated output when both are configured.
logger.propagate = False
# NOTE(review): the three entries below look like the body of a dict literal
# mapping metric keys ("split/type") to fresh `MeanMetric` instances. Its
# opening line (assignment + `{`) and closing `}` are missing from this view,
# so as shown these lines are not valid Python — restore them from the
# original module.
"train/loss": MeanMetric(),
"val/loss": MeanMetric(),
"test/loss": MeanMetric(),
def _is_primitive(value):
    """Return whether ``value`` is built purely from plain Python scalars.

    ``None``, ``bool``, ``str``, ``int`` and ``float`` are primitive.
    A ``tuple`` or ``list`` is primitive when every element is, and a
    ``dict`` is primitive when every *value* is (keys are not inspected).
    Any other object is not primitive.
    """
    is_scalar = value is None or isinstance(value, (bool, str, int, float))
    if is_scalar:
        return True

    # Pick the children to check recursively, or bail out for other types.
    if isinstance(value, (tuple, list)):
        children = value
    elif isinstance(value, dict):
        children = value.values()
    else:
        return False

    return all(map(_is_primitive, children))
def _cast_init_arg(value):
    """Convert a constructor argument into a plain Python value.

    - ``inspect.Parameter``: its declared default is returned
      (``inspect.Parameter.empty`` when the parameter has no default).
    - OmegaConf ``ListConfig`` / ``DictConfig``: converted to a native
      ``list`` / ``dict`` via ``OmegaConf.to_container`` with
      interpolations resolved.
    - Anything else is returned unchanged.
    """
    if isinstance(value, inspect.Parameter):
        # Use the public ``default`` property, not the private ``_default``
        # attribute, which is a CPython implementation detail.
        return value.default
    if isinstance(value, (ListConfig, DictConfig)):
        return OmegaConf.to_container(value, resolve=True)
    return value
[docs]class BaseModel(LightningModule, metaclass=BaseModelMeta):
    # Base LightningModule: subclasses implement `parse_batch`, `forward`,
    # `model_step` and `predict_step`; this class wires metric registration,
    # metric updating/logging and the train/val/test step plumbing.
    #
    # NOTE(review): this class appears to come from a rendered Sphinx
    # "view source" page. The `[docs]` prefixes are link residue, and several
    # lines are missing (flagged below), so as shown the class is not valid
    # Python. Restore it from the original module. `BaseModelMeta` is not
    # defined in the visible part of this file.
    def __init__(
        optimizer: Optional[torch.optim.Optimizer] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
        # NOTE(review): missing here — `self`, the `metrics` parameter used
        # below (a mapping of metric key -> metric object), the closing `):`
        # and presumably `super().__init__()`.
        # Remember the metric keys and register each metric object as an
        # attribute, so `compute_metrics` can fetch it with `getattr`. If the
        # values are torchmetrics `Metric`s (nn.Modules), this also registers
        # them as submodules.
        self.metrics = tuple(metrics.keys())
        for key, value in metrics.items():
            setattr(self, key, value)
        # The fallback is the `torch.optim.Adam` *class*, not an instance,
        # although the annotation says `Optimizer`. Presumably it is
        # instantiated later with the model parameters — confirm.
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam
        self.lr_scheduler = lr_scheduler
    [docs] def parse_batch(self, batch):
        """Extract the model input from ``batch``; must be overridden."""
        raise NotImplementedError
    [docs] def forward(self, x, **kwargs):
        """Forward pass of the model; must be overridden."""
        raise NotImplementedError
    [docs] def compute_metrics(self, loss, preds, targets, split):
        """Update and log every registered metric that belongs to ``split``.

        Assumptions made:

        - ``model_step`` returns a tuple ``(loss, preds, targets)``, whose
          elements may be dictionaries;
        - the keys of ``self.metrics`` have the structure
          ``'split/type(/part)'``, where:

          - ``split`` is one of (train, val, test, predict). A metric is
            selected when its ``split`` segment *starts with* the ``split``
            argument;
          - ``type`` is either ``"loss"`` or an arbitrary string denoting a
            metric;
          - ``part`` is optional and is used when ``(loss, preds, targets)``
            are dictionaries, in which case it must match a dictionary key.
            Everything after the second ``/`` forms the part, so it may
            itself contain ``/``.

        Each selected metric is logged with ``self.log(metric_key, metric,
        on_step=True, on_epoch=True, prog_bar=True)``.

        Raises:
            TypeError: for a ``loss`` metric without a part, when ``loss`` is
                a mapping without a ``"loss"`` key (per the error message;
                the surrounding branch lines are partly missing here).
            ValueError: from the tuple unpacking, if a metric key has fewer
                than two ``/``-separated segments.
        """
        for metric_key in self.metrics:
            metric_split, metric_type, *metric_part = metric_key.split("/")
            if not metric_split.startswith(split):
                # NOTE(review): body missing — presumably `continue`, to skip
                # metrics that belong to another split.
            if len(metric_part) > 0:
                metric_part = "/".join(metric_part)
            # NOTE(review): an `else:` appears to be missing here; otherwise
            # the next line would always discard the joined part.
                metric_part = None
            metric = getattr(self, metric_key)
            if metric_type == "loss":
                if metric_part is not None:
                    # NOTE(review): missing — presumably
                    # `metric.update(loss[metric_part])`, followed by `else:`.
                    if not isinstance(loss, MutableMapping):
                        # NOTE(review): missing — presumably `metric.update(loss)`.
                    elif "loss" in loss:
                        # NOTE(review): missing — presumably
                        # `metric.update(loss["loss"])`, followed by `else:`.
                        raise TypeError(
                            "Expected `loss` to be a single value or tensor, "
                            "or a dictionary with a key 'loss', but it isn't."
                        # NOTE(review): closing `)` missing.
            # NOTE(review): an `else:` (non-loss metrics) appears to be missing.
                if metric_part is not None:
                    metric.update(preds[metric_part], targets[metric_part])
                # NOTE(review): the `else:`/`elif` structure around this check
                # may be missing. As shown, a part-less metric with mapping
                # `preds` is never updated, but it is still logged below.
                if not isinstance(preds, MutableMapping):
                    metric.update(preds, targets)
            self.log(metric_key, metric, on_step=True, on_epoch=True, prog_bar=True)
    [docs] def model_step(self, stage, batch, batch_idx):
        """Here you should implement the logic for a step in the training/validation/test process.
        The stage (training/validation/test) is given by the variable `stage`.

        The ``*_step`` methods unpack the result as ``(loss, preds, targets)``.
        """
        # NOTE(review): the lines down to `return results` are probably
        # example code that belonged inside the docstring above (the method
        # otherwise ends with `raise NotImplementedError`). As shown,
        # `results` is never assigned, and the two dict entries lack an
        # enclosing literal (e.g. `results = {` ... `}`).
        x = self.parse_batch(batch)
        if self.hparams.id_label is not None:
            if self.hparams.id_label in batch:
                ids = batch[self.hparams.id_label].detach().cpu()
                self.hparams.id_label: ids,
                "id": ids
        return results
        raise NotImplementedError
    [docs] def on_train_start(self):
        """Lightning hook: look up every metric whose split starts with "val"."""
        for metric_key in self.metrics:
            metric_split, *_ = metric_key.split("/")
            if metric_split.startswith("val"):
                metric = getattr(self, metric_key)
                # NOTE(review): `metric` is fetched but not used in the
                # visible code. A following line — presumably `metric.reset()`,
                # to drop state accumulated during Lightning's sanity-check
                # validation — appears to be missing.
    [docs] def training_step(self, batch, batch_idx):
        """Run ``model_step`` for "train", update/log train metrics, return the loss."""
        loss, preds, targets = self.model_step("train", batch, batch_idx)
        self.compute_metrics(loss, preds, targets, "train")
        return loss
    [docs] def validation_step(self, batch, batch_idx):
        """Run ``model_step`` for "val", update/log val metrics, return the loss."""
        loss, preds, targets = self.model_step("val", batch, batch_idx)
        self.compute_metrics(loss, preds, targets, "val")
        return loss
    [docs] def test_step(self, batch, batch_idx):
        """Run ``model_step`` for "test", update/log test metrics, return the loss."""
        loss, preds, targets = self.model_step("test", batch, batch_idx)
        self.compute_metrics(loss, preds, targets, "test")
        return loss
    [docs] def predict_step(self, batch, batch_idx):
        """Here you should implement the logic for an inference step.
        In most cases this would simply consist of calling the forward pass of your model, but you
        might wish to add additional post-processing.
        """
        raise NotImplementedError