Source code for cyto_dl.eval

import os
from collections.abc import MutableMapping
from contextlib import suppress
from typing import List, Tuple

import hydra
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from monai.data import DataLoader as MonaiDataLoader
from omegaconf import DictConfig, ListConfig, OmegaConf
from torch.utils.data import DataLoader as TorchDataLoader

from cyto_dl import utils

log = utils.get_pylogger(__name__)

# Register the custom OmegaConf resolvers used by the configs.
# `register_new_resolver` raises ValueError when a resolver of the same name
# already exists (e.g. this module or a sibling entry point was imported
# before). Each registration gets its own `suppress` so that a duplicate of the
# first resolver does not silently skip registration of the second one.
with suppress(ValueError):
    OmegaConf.register_new_resolver("kv_to_dict", utils.kv_to_dict)
# NOTE(review): this exposes the builtin `eval` to config interpolations
# (`${eval:...}`), i.e. configs can execute arbitrary Python. Only load
# configs from trusted sources.
with suppress(ValueError):
    OmegaConf.register_new_resolver("eval", eval)


@utils.task_wrapper
def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]:
    """Run testing or prediction for a given checkpoint.

    The model, loggers, callbacks and trainer are instantiated from ``cfg``,
    the checkpoint is loaded, and then either ``trainer.test`` (when
    ``cfg.test`` is truthy) or ``trainer.predict`` (otherwise) is called.

    This method is wrapped in the ``@task_wrapper`` decorator, which applies
    extra utilities before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. It is resolved in
            place and its aux section is removed before instantiation.
        data: Optional object forwarded, together with ``cfg.data``, to
            ``utils.create_dataloader``. The result of that call must be a
            LightningDataModule, a DataLoader (or list thereof), or a mapping
            with a non-empty ``dataloaders`` entry.

    Returns:
        Tuple[dict, dict, dict]: ``(metric_dict, object_dict, output)`` where
        ``metric_dict`` is ``trainer.callback_metrics``, ``object_dict`` holds
        the instantiated objects (cfg, data, model, logger, trainer) and
        ``output`` is whatever the called trainer method returned.

    Raises:
        ValueError: If ``cfg.checkpoint.ckpt_path`` is empty, or if the data
            object is not one of the supported kinds listed above.
    """

    if not cfg.checkpoint.ckpt_path:
        raise ValueError("Checkpoint path must be included for testing")

    # resolve config to avoid unresolvable interpolations in the stored config
    OmegaConf.resolve(cfg)

    # remove aux section after resolving and before instantiating
    utils.remove_aux_key(cfg)
    data = utils.create_dataloader(cfg.data, data)

    # `dataloaders` is what is actually handed to the trainer; `data` is kept
    # unchanged for `object_dict`.
    dataloaders = data
    if isinstance(data, MutableMapping):
        # Use `.get` so both plain dicts and DictConfigs work, and unwrap the
        # `dataloaders` entry as described by the error messages below.
        dataloaders = data.get("dataloaders")
        # Avoid plain truthiness here: it would call `len()` on a DataLoader.
        is_empty_sequence = (
            isinstance(dataloaders, (list, tuple, ListConfig)) and len(dataloaders) == 0
        )
        if dataloaders is None or is_empty_sequence:
            raise ValueError(
                "If the `data` config for eval/prediction is a dict it must have a "
                "key `dataloaders` with the corresponding value being a DataLoader "
                "(or list thereof)."
            )
    elif not isinstance(
        data,
        (LightningDataModule, TorchDataLoader, MonaiDataLoader, list, ListConfig),
    ):
        raise ValueError(
            "`data` config for eval/prediction must be either:\n"
            " - a LightningDataModule\n"
            " - a DataLoader (or list thereof)\n"
            " - a dict with key `dataloaders`, with the corresponding value "
            "being a DataLoader (or list thereof)"
        )

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False)

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, callbacks=callbacks)

    object_dict = {
        "cfg": cfg,
        "data": data,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    model, load_params = utils.load_checkpoint(model, cfg.get("checkpoint"))

    log.info("Starting testing!")
    # `cfg.test` selects between test mode and (default) prediction mode.
    method = trainer.test if cfg.get("test", False) else trainer.predict
    output = method(model=model, dataloaders=dataloaders, ckpt_path=load_params.ckpt_path)
    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict, output


@hydra.main(
    version_base="1.3",
    config_path=os.environ.get("CYTODL_CONFIG_PATH", "../configs"),
    config_name="eval.yaml",
)
def main(cfg: DictConfig) -> None:
    """Hydra entry point that runs :func:`evaluate` on the composed config.

    The config directory is taken from the ``CYTODL_CONFIG_PATH`` environment
    variable, falling back to ``../configs``; the config name is ``eval.yaml``.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
    """
    evaluate(cfg)


# Only launch the Hydra application when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()