from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path
from typing import Any, List, Optional, Union
import pyrootutils
from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf, open_dict
from cyto_dl.api.data import ExperimentType
from cyto_dl.eval import evaluate as evaluate_model
from cyto_dl.train import train as train_model
# TODO: encapsulate experiment management (file system) details here, will require passing output_dir
# into the factory methods, maybe
class CytoDLBaseModel(ABC):
    """A CytoDLBaseModel is used to configure, train, and run predictions on a cyto-dl model.

    Subclasses provide the experiment type and implement the setters that know where the
    max-epochs, manifest-path and output-dir settings live in the config tree.
    """

    def __init__(self, cfg: DictConfig):
        """Not intended for direct use by clients.

        Please see the classmethod factory methods instead.

        :param cfg: composed config; it is stored as-is and mutated in place by the
            setters, ``train`` and ``predict``.
        """
        self._cfg: DictConfig = cfg

    @classmethod
    @abstractmethod
    def _get_experiment_type(cls) -> ExperimentType:
        """Return experiment type for this config (e.g. segmentation_plugin, gan, etc)"""
        pass

    @classmethod
    def from_existing_config(cls, config_filepath: Union[str, Path]):
        """Returns a model from an existing config.

        :param config_filepath: path to a .yaml config file that will be used as the basis for
            this CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that wants
            to use it).
        """
        return cls(OmegaConf.load(config_filepath))

    # TODO: if spatial_dims is only ever 2 or 3, create an enum for it
    @classmethod
    def from_default_config(cls, spatial_dims: int):
        """Returns a model from the default config.

        The config is composed with Hydra from ``train.yaml`` in the project's ``configs``
        directory, using the experiment ``im2im/<experiment type name, lowercased>``.

        :param spatial_dims: dimensions for the model (e.g. 2)
        """
        # Project root is located by pyrootutils, searching from this file for either
        # pyproject.toml or README.md.
        cfg_dir: Path = (
            pyrootutils.find_root(search_from=__file__, indicator=("pyproject.toml", "README.md"))
            / "configs"
        )
        # Reset any previously initialized global Hydra state before initializing again.
        GlobalHydra.instance().clear()
        with initialize_config_dir(version_base="1.2", config_dir=str(cfg_dir)):
            cfg: DictConfig = compose(
                config_name="train.yaml",  # train.yaml can work for prediction too
                return_hydra_config=True,
                overrides=[
                    f"experiment=im2im/{cls._get_experiment_type().name.lower()}",
                    f"spatial_dims={spatial_dims}",
                ],
            )
            with open_dict(cfg):
                # Drop the hydra node requested via return_hydra_config=True and turn off
                # the tag-enforcement and config-printing extras.
                del cfg["hydra"]
                cfg.extras.enforce_tags = False
                cfg.extras.print_config = False

        return cls(cfg)

    @abstractmethod
    def _set_max_epochs(self, max_epochs: int) -> None:
        pass

    @abstractmethod
    def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None:
        pass

    @abstractmethod
    def _set_output_dir(self, output_dir: Union[str, Path]) -> None:
        pass

    def _key_exists(self, k: str) -> bool:
        """Return whether the dotted key ``k`` (e.g. ``"checkpoint.ckpt_path"``) is present.

        Only DictConfig nodes are descended into. A path that runs through a scalar or None
        value (e.g. ``"experiment_name.foo"``) is reported as absent. Such paths were
        previously checked with ``str.__contains__`` (a substring match) or raised TypeError.

        NOTE(review): membership uses DictConfig's ``in`` operator, which in OmegaConf treats
        mandatory-missing (``???``) values as absent. Confirm that is intended for keys that
        ``_set_cfg`` should be allowed to fill in.

        :param k: dot-separated key path
        """
        *parents, leaf = k.split(".")
        node: Any = self._cfg
        for key in parents:
            if not isinstance(node, DictConfig) or key not in node:
                return False
            node = node[key]
        # The leaf value itself is never read, so an unresolvable interpolation stored at
        # the leaf does not raise here.
        return isinstance(node, DictConfig) and leaf in node

    def _set_cfg(self, k: str, v: Any) -> None:
        """Set an existing dotted key ``k`` to ``v`` in the config.

        :raises KeyError: if ``k`` is not already present in the config
        """
        if not self._key_exists(k):
            raise KeyError(f"{k} not found in config dict")
        OmegaConf.update(self._cfg, k, v)

    def _get_cfg(self, k: str) -> Any:
        """Return the value stored at the existing dotted key ``k``.

        :raises KeyError: if ``k`` is not present in the config
        """
        if not self._key_exists(k):
            raise KeyError(f"{k} not found in config dict")
        return OmegaConf.select(self._cfg, k)

    def _set_training_config(self, train: bool):
        """Set the ``train``/``test`` flags to ``train`` and ``task_name`` to train/predict."""
        self._set_cfg("train", train)
        self._set_cfg("test", train)
        # afaik, task_name isn't used outside of template_utils.py - do we need to support this?
        self._set_cfg("task_name", "train" if train else "predict")

    def _set_ckpt(self, ckpt: Optional[Union[str, Path]]) -> None:
        """Store ``ckpt`` as an absolute path string at ``checkpoint.ckpt_path``.

        :param ckpt: checkpoint file as ``str`` or ``Path``. A falsy value (None or an empty
            string) stores None.
        """
        # Wrapping in Path(...) lets callers pass plain strings; calling .resolve() directly
        # on a str raised AttributeError.
        self._set_cfg("checkpoint.ckpt_path", str(Path(ckpt).resolve()) if ckpt else None)

    # TODO: does experiment name have any effect?
    def set_experiment_name(self, name: str) -> None:
        """Set ``experiment_name`` in the config."""
        self._set_cfg("experiment_name", name)

    def get_experiment_name(self) -> str:
        """Return ``experiment_name`` from the config."""
        return self._get_cfg("experiment_name")

    def get_config(self) -> DictConfig:
        """Return a deep copy of the config, so callers cannot mutate this model's state."""
        return deepcopy(self._cfg)

    def save_config(self, path: Union[str, Path]) -> None:
        """Write the current config to ``path`` via ``OmegaConf.save``."""
        OmegaConf.save(self._cfg, path)

    def train(
        self,
        max_epochs: int,
        manifest_path: Union[str, Path],
        output_dir: Union[str, Path],
        checkpoint: Optional[Union[str, Path]] = None,
    ) -> None:
        """Configure this model for training and run ``cyto_dl.train.train`` on it.

        Sets ``train``/``test`` to True and ``task_name`` to "train", and writes the given
        settings into the config in place before training.

        :param max_epochs: maximum number of training epochs
        :param manifest_path: data manifest to train on
        :param output_dir: directory for training outputs
        :param checkpoint: optional checkpoint (``str`` or ``Path``) to start from
        """
        self._set_training_config(True)
        self._set_max_epochs(max_epochs)
        self._set_manifest_path(manifest_path)
        self._set_output_dir(output_dir)
        self._set_ckpt(checkpoint)
        train_model(self._cfg)

    def predict(
        self,
        manifest_path: Union[str, Path],
        output_dir: Union[str, Path],
        checkpoint: Union[str, Path],
    ) -> None:
        """Configure this model for prediction and run ``cyto_dl.eval.evaluate`` on it.

        Sets ``train``/``test`` to False and ``task_name`` to "predict", and writes the given
        settings into the config in place before evaluating.

        :param manifest_path: data manifest to predict on
        :param output_dir: directory for prediction outputs
        :param checkpoint: checkpoint (``str`` or ``Path``) to load
        """
        self._set_training_config(False)
        self._set_manifest_path(manifest_path)
        self._set_output_dir(output_dir)
        self._set_ckpt(checkpoint)
        evaluate_model(self._cfg)