Source code for cyto_dl.api.cyto_dl_model.cyto_dl_base_model

from abc import ABC, abstractmethod
from collections.abc import Mapping
from copy import deepcopy
from pathlib import Path
from typing import Any, List, Optional, Union

import pyrootutils
from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf, open_dict

from cyto_dl.api.data import ExperimentType
from cyto_dl.eval import evaluate as evaluate_model
from cyto_dl.train import train as train_model

# TODO: encapsulate experiment management (file system) details here, will require passing output_dir
# into the factory methods, maybe


class CytoDLBaseModel(ABC):
    """A CytoDLBaseModel is used to configure, train, and run predictions on a cyto-dl model.

    Subclasses supply the experiment type and know which config keys hold the
    epoch count, the manifest path and the output directory; this base class
    owns the config object and the train/predict entry points.
    """

    def __init__(self, cfg: DictConfig):
        """Not intended for direct use by clients.

        Please see the classmethod factory methods instead.

        :param cfg: config wrapped by this model; it is stored as-is (not copied)
            and mutated in place by the setters below.
        """
        self._cfg: DictConfig = cfg

    @classmethod
    @abstractmethod
    def _get_experiment_type(cls) -> ExperimentType:
        """Return experiment type for this config (e.g. segmentation_plugin, gan, etc)"""
        pass

    @classmethod
    def from_existing_config(cls, config_filepath: Path):
        """Returns a model from an existing config.

        :param config_filepath: path to a .yaml config file that will be used as the basis
            for this CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that
            wants to use it).
        """
        return cls(OmegaConf.load(config_filepath))

    # TODO: if spatial_dims is only ever 2 or 3, create an enum for it
    @classmethod
    def from_default_config(cls, spatial_dims: int):
        """Returns a model from the default config.

        :param spatial_dims: dimensions for the model (e.g. 2)
        """
        # The default configs live in <project root>/configs, where the project root is
        # located by searching upwards from this file for the indicator files.
        cfg_dir: Path = (
            pyrootutils.find_root(search_from=__file__, indicator=("pyproject.toml", "README.md"))
            / "configs"
        )

        # Reset the global Hydra instance before initializing it for this compose call.
        GlobalHydra.instance().clear()
        with initialize_config_dir(version_base="1.2", config_dir=str(cfg_dir)):
            cfg: DictConfig = compose(
                config_name="train.yaml",  # train.yaml can work for prediction too
                return_hydra_config=True,
                overrides=[
                    f"experiment=im2im/{cls._get_experiment_type().name.lower()}",
                    f"spatial_dims={spatial_dims}",
                ],
            )

        # The hydra section was requested for composition only; drop it and switch off
        # the two extras flags before handing the config to the model.
        with open_dict(cfg):
            del cfg["hydra"]
            cfg.extras.enforce_tags = False
            cfg.extras.print_config = False
        return cls(cfg)

    @abstractmethod
    def _set_max_epochs(self, max_epochs: int) -> None:
        """Store the maximum number of training epochs in the config."""
        pass

    @abstractmethod
    def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None:
        """Store the path of the data manifest in the config."""
        pass

    @abstractmethod
    def _set_output_dir(self, output_dir: Union[str, Path]) -> None:
        """Store the output directory in the config."""
        pass

    def _key_exists(self, k: str) -> bool:
        """Return True if the dotted key ``k`` (e.g. ``"a.b.c"``) resolves in the config.

        Every component except the last must name a mapping. A path that runs through
        a scalar, ``None`` or a list is reported as missing rather than raising.

        :param k: dot-separated path into the config
        """
        node: Any = self._cfg
        for part in k.split("."):
            # Guard before using ``in``: on a str it would be a substring test and on
            # None/int it would raise TypeError instead of reporting "not found".
            if not isinstance(node, Mapping) or part not in node:
                return False
            node = node[part]
        return True

    def _set_cfg(self, k: str, v: Any) -> None:
        """Overwrite the value at dotted key ``k``.

        Only existing keys may be set, so a typo cannot silently add a new entry.

        :param k: dot-separated path into the config
        :param v: new value
        :raises KeyError: if ``k`` is not present in the config
        """
        if not self._key_exists(k):
            raise KeyError(f"{k} not found in config dict")
        OmegaConf.update(self._cfg, k, v)

    def _get_cfg(self, k: str) -> Any:
        """Return the value stored at dotted key ``k``.

        :param k: dot-separated path into the config
        :raises KeyError: if ``k`` is not present in the config
        """
        if not self._key_exists(k):
            raise KeyError(f"{k} not found in config dict")
        return OmegaConf.select(self._cfg, k)

    def _set_training_config(self, train: bool) -> None:
        """Switch the config between training mode and prediction mode.

        :param train: True to configure for training, False for prediction
        """
        self._set_cfg("train", train)
        # "test" mirrors "train": on when training, off when predicting.
        self._set_cfg("test", train)
        # afaik, task_name isn't used outside of template_utils.py - do we need to support this?
        self._set_cfg("task_name", "train" if train else "predict")

    def _set_ckpt(self, ckpt: Optional[Union[str, Path]]) -> None:
        """Store the checkpoint path in the config as an absolute path string.

        :param ckpt: checkpoint file as a ``Path`` or ``str``; a falsy value
            (e.g. ``None``) clears the setting
        """
        ckpt_path: Optional[str] = str(Path(ckpt).resolve()) if ckpt else None
        self._set_cfg("checkpoint.ckpt_path", ckpt_path)

    # does experiment name have any effect?
    def set_experiment_name(self, name: str) -> None:
        """Set the ``experiment_name`` entry of the config.

        :param name: new experiment name
        """
        self._set_cfg("experiment_name", name)

    def get_experiment_name(self) -> str:
        """Return the ``experiment_name`` entry of the config."""
        return self._get_cfg("experiment_name")

    def get_config(self) -> DictConfig:
        """Return a deep copy of the config, so callers cannot mutate this model's state."""
        return deepcopy(self._cfg)

    def save_config(self, path: Path) -> None:
        """Write the current config to ``path``.

        :param path: destination file
        """
        OmegaConf.save(self._cfg, path)

    def train(
        self,
        max_epochs: int,
        manifest_path: Union[str, Path],
        output_dir: Union[str, Path],
        checkpoint: Optional[Union[str, Path]] = None,
    ) -> None:
        """Configure the model for training and run it.

        :param max_epochs: maximum number of epochs to train for
        :param manifest_path: path to the data manifest
        :param output_dir: directory passed to the subclass as the output location
        :param checkpoint: optional checkpoint to store in the config before training
        """
        self._set_training_config(True)
        self._set_max_epochs(max_epochs)
        self._set_manifest_path(manifest_path)
        self._set_output_dir(output_dir)
        self._set_ckpt(checkpoint)
        train_model(self._cfg)

    def predict(
        self,
        manifest_path: Union[str, Path],
        output_dir: Union[str, Path],
        checkpoint: Union[str, Path],
    ) -> None:
        """Configure the model for prediction and run evaluation.

        :param manifest_path: path to the data manifest
        :param output_dir: directory passed to the subclass as the output location
        :param checkpoint: checkpoint to store in the config before evaluating
        """
        self._set_training_config(False)
        self._set_manifest_path(manifest_path)
        self._set_output_dir(output_dir)
        self._set_ckpt(checkpoint)
        evaluate_model(self._cfg)