# Source code for cyto_dl.models.basic_model

from typing import Callable, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig, ListConfig
from torch.nn.modules.loss import _Loss as Loss
from torchmetrics import MeanMetric

from .base_model import BaseModel

# Convenience alias: anything this module treats as a numeric array
# (a torch tensor, a numpy array, or a plain sequence of floats).
Array = Union[torch.Tensor, np.ndarray, Sequence[float]]


class BasicModel(BaseModel):
    """A minimal Pytorch Lightning wrapper around generic Pytorch models."""

    def __init__(
        self,
        network: Optional[nn.Module] = None,
        loss: Optional[Loss] = None,
        x_label: str = "x",
        optimizer: torch.optim.Optimizer = torch.optim.Adam,
        save_predictions: Optional[Callable] = None,
        fields_to_log: Optional[Sequence] = None,
        pretrained_weights: Optional[str] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        network: Optional[nn.Module] = None
            The network to wrap. Assumes that the network outputs
            predictions and ground truth (see `model_step`).
        loss: Optional[Loss] = None
            The loss function to optimize for. Defaults to `nn.MSELoss`
            when not provided.
        x_label: str = "x"
            The key used to retrieve the input from dataloader batches
        optimizer: torch.optim.Optimizer = torch.optim.Adam
            The optimizer class
        save_predictions: Optional[Callable] = None
            A function to save the results of `serotiny predict`
        fields_to_log: Optional[Sequence] = None
            List of batch fields to store with the outputs.
            NOTE(review): stored as-is here; actual consumption happens in
            `BaseModel` — confirm per-stage semantics there.
        pretrained_weights: Optional[str] = None
            Path to pretrained weights. If `network` is not None, this will
            be loaded via `network.load_state_dict`, otherwise the object
            returned by `torch.load` is used as the network itself.
        """
        # Only build the default metric objects if the caller didn't
        # supply their own via kwargs.
        metrics = kwargs.pop("metrics", None)
        if metrics is None:
            metrics = {
                "train/loss": MeanMetric(),
                "val/loss": MeanMetric(),
                "test/loss": MeanMetric(),
            }
        super().__init__(metrics=metrics)

        if network is None and pretrained_weights is None:
            raise ValueError("`network` and `pretrained_weights` can't both be None.")

        if pretrained_weights is not None:
            pretrained_weights = torch.load(pretrained_weights)

        if network is not None:
            self.network = network
            if pretrained_weights is not None:
                # Checkpoint is interpreted as a state dict for `network`.
                self.network.load_state_dict(pretrained_weights)
        else:
            # No network given: the loaded checkpoint is assumed to be a
            # full serialized model and is used directly.
            self.network = pretrained_weights

        self.loss = loss if loss is not None else nn.MSELoss()

        # Only override the inherited `save_predictions` when one is given.
        if save_predictions is not None:
            self.save_predictions = save_predictions
        self.fields_to_log = fields_to_log
        self.x_label = x_label

    def forward(self, x, **kwargs):
        """Apply the wrapped network to `x`, forwarding any extra kwargs."""
        return self.network(x, **kwargs)

    def model_step(self, stage, batch, batch_idx):
        """Run a single step for any stage.

        The network is expected to return a `(prediction, ground_truth)`
        pair for the batch field keyed by `self.x_label`; the configured
        loss between them is reduced to a scalar with `.mean()`.
        """
        rec, gt = self.forward(batch[self.x_label])
        loss = self.loss(rec, gt).mean()
        output = {
            "loss": loss,
        }
        return output, None, None