Source code for cyto_dl.utils.template_utils

import subprocess  # nosec: B404
import sys
import time
import warnings
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Callable, List, Optional

import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig

from cyto_dl.loggers import MLFlowLogger

from . import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)

__all__ = [
    "task_wrapper",
    "extras",
    "save_file",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "get_metric_value",
    "close_loggers",
]


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling `utils.extras()` before the task is started
    - Calling `utils.close_loggers()` after the task is finished
    - Logging the exception if one occurs
    - Logging the task's total execution time
    - Logging the output dir
    """

    def wrap(cfg: DictConfig, data: Any = None):
        # apply extra utilities
        extras(cfg)

        # execute the task
        try:
            start_time = time.time()
            out = task_func(cfg=cfg, data=data)
        except Exception as ex:
            log.exception("")  # save exception to `.log` file
            raise ex
        finally:
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
            save_file(path, content)  # save task execution time (even if exception occurs)
            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)

        log.info(f"Output dir: {cfg.paths.output_dir}")

        return out

    return wrap

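# Hypothetical usage sketch (not part of this module): `task_wrapper` decorates a
# Hydra task function that takes a DictConfig. The task body, return values, and
# function names below are assumptions for illustration, not cyto-dl's actual
# train task; a real `cfg` must provide `paths.output_dir` and `task_name`.
def _example_task_wrapper_usage(cfg: DictConfig) -> None:
    @task_wrapper
    def _example_task(cfg: DictConfig, data: Any = None):
        # a real task would instantiate model/data/trainer from `cfg` and run them
        metric_dict = {"val/loss": None}
        object_dict = {"cfg": cfg}
        return metric_dict, object_dict

    # calling the wrapped task runs extras(), timing, exception logging,
    # and close_loggers() around the task body
    _example_task(cfg=cfg)

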
def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing
    """

    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)

    if cfg.extras.get("precision"):
        hydra.utils.instantiate(cfg.extras)

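# Hypothetical config sketch (not part of this module): `extras` only reads the
# flags it finds under `cfg.extras`; the keys and values below mirror the checks
# above and are illustrative, not a canonical cyto-dl config.
def _example_extras_cfg() -> DictConfig:
    from omegaconf import OmegaConf

    return OmegaConf.create(
        {
            "task_name": "train",
            "extras": {
                "ignore_warnings": True,  # silence python warnings
                "enforce_tags": False,  # skip the interactive tag prompt
                "print_config": True,  # pretty-print the config tree with Rich
            },
        }
    )

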
@rank_zero_only
def save_file(path: str, content: str) -> None:
    """Save file in rank zero mode (only on one process in multi-GPU setup)."""
    with open(path, "w+") as file:  # noqa: FURB103
        file.write(content)

def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config."""
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("Callbacks config is empty.")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks

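# Hypothetical usage sketch (not part of this module): `instantiate_callbacks`
# expects a DictConfig whose values each carry a `_target_`. ModelCheckpoint is a
# real Lightning callback, but the key names and arguments below are illustrative.
def _example_instantiate_callbacks() -> List[Callback]:
    from omegaconf import OmegaConf

    callbacks_cfg = OmegaConf.create(
        {
            "model_checkpoint": {
                "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
                "monitor": "val/loss",
                "save_top_k": 1,
            }
        }
    )
    return instantiate_callbacks(callbacks_cfg)

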
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config."""
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("Logger config is empty.")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger

@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters
    """

    hparams = {}

    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg.get("model")

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.checkpoint.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    try:
        reqs = subprocess.check_output([sys.executable, "-m", "pip", "freeze"])  # nosec: B603
        hparams["requirements"] = str(reqs).split("\\n")
    except subprocess.CalledProcessError:
        # not mandatory to save requirements; allows segmenter plugin devs to use PDM
        pass

    # send hparams to all loggers
    for logger in trainer.loggers:
        if isinstance(logger, MLFlowLogger):
            logger.log_hyperparams(hparams, mode=cfg.task_name)
        else:
            logger.log_hyperparams(hparams)

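# Hypothetical usage sketch (not part of this module): `log_hyperparameters` reads
# `cfg`, `model`, and `trainer` from `object_dict`, and assumes `cfg.data`,
# `cfg.trainer`, and `cfg.checkpoint.ckpt_path` exist. The parameter names below
# are placeholders for objects a task function would already have built.
def _example_log_hyperparameters(cfg: DictConfig, model, trainer) -> None:
    object_dict = {"cfg": cfg, "model": model, "trainer": trainer}
    log_hyperparameters(object_dict)

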
def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieves the value of the metric logged in LightningModule."""

    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value

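# Hypothetical usage sketch (not part of this module): `get_metric_value` expects a
# dict of tensor-like values with `.item()`, such as `trainer.callback_metrics`.
# The metric name "val/loss" is an assumption for illustration.
def _example_get_metric_value() -> Optional[float]:
    import torch

    metric_dict = {"val/loss": torch.tensor(0.123)}
    return get_metric_value(metric_dict, "val/loss")

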
def close_loggers() -> None:
    """Makes sure all loggers are closed properly (prevents logging failure during multirun)."""

    log.info("Closing loggers...")

    if find_spec("wandb"):  # check if wandb is installed
        import wandb

        if wandb.run:
            log.info("Closing wandb!")
            wandb.finish()