import subprocess  # nosec: B404
import sys
import time
import warnings
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Callable, List, Optional

import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig

from cyto_dl.loggers import MLFlowLogger

from . import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)

__all__ = [
    "task_wrapper",
    "extras",
    "save_file",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "get_metric_value",
    "close_loggers",
]


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling `utils.extras()` before the task is started
    - Calling `utils.close_loggers()` after the task is finished
    - Logging the exception if one occurs
    - Logging the total task execution time
    - Logging the output dir

    See the usage sketch below this function.
    """

    def wrap(cfg: DictConfig, data: Any = None):
        # apply extra utilities
        extras(cfg)

        # execute the task
        try:
            start_time = time.time()
            out = task_func(cfg=cfg, data=data)
        except Exception as ex:
            log.exception("")  # save exception to `.log` file
            raise ex
        finally:
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
            save_file(path, content)  # save task execution time (even if exception occurs)
            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)

        log.info(f"Output dir: {cfg.paths.output_dir}")

        return out

    return wrap
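

# Hedged usage sketch (not part of the original module): how `task_wrapper` is
# typically applied to a Hydra task function. The task body, config keys, and the
# returned dicts below are illustrative assumptions, not the library's API.
def _example_task(cfg: DictConfig, data: Any = None):
    """Toy task body; a real task would build and fit a model from `cfg`."""
    metric_dict = {"val/loss": 0.0}  # placeholder metrics
    object_dict = {}  # placeholder objects (model, trainer, loggers, ...)
    return metric_dict, object_dict


# Decorated form; calling it adds timing, exception logging, and logger cleanup:
# wrapped_task = task_wrapper(_example_task)
# metric_dict, object_dict = wrapped_task(cfg)  # `cfg` is a composed Hydra DictConfig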


@rank_zero_only
def save_file(path: str, content: str) -> None:
    """Save file in rank zero mode (only on one process in multi-GPU setup)."""
    with open(path, "w+") as file:  # noqa: FURB103
        file.write(content)
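

# Hedged usage sketch: because of `rank_zero_only`, every process in a multi-GPU
# run can call this unconditionally and only global rank 0 writes the file. The
# path and content below are illustrative, not values the library itself writes.
# save_file(Path(cfg.paths.output_dir, "tags.log"), "\n".join(cfg.get("tags", [])))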


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config."""
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("Callbacks config is empty.")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks
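

# Hedged usage sketch: a minimal callbacks config as Hydra would compose it.
# `ModelCheckpoint` is a real Lightning callback; the config key and its values
# are illustrative assumptions.
def _example_instantiate_callbacks() -> List[Callback]:
    from omegaconf import OmegaConf

    callbacks_cfg = OmegaConf.create(
        {
            "model_checkpoint": {
                "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
                "monitor": "val/loss",
                "save_top_k": 1,
            }
        }
    )
    return instantiate_callbacks(callbacks_cfg)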


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config."""
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("Logger config is empty.")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
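

# Hedged usage sketch: a single-logger config; `CSVLogger` is a real Lightning
# logger, while the config key and `save_dir` are illustrative. Entries without
# a `_target_` field are skipped silently by the loop above.
def _example_instantiate_loggers() -> List[Logger]:
    from omegaconf import OmegaConf

    logger_cfg = OmegaConf.create(
        {"csv": {"_target_": "lightning.pytorch.loggers.CSVLogger", "save_dir": "./logs"}}
    )
    return instantiate_loggers(logger_cfg)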


@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters
    """

    hparams = {}

    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg.get("model")

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.checkpoint.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    try:
        reqs = subprocess.check_output([sys.executable, "-m", "pip", "freeze"])  # nosec: B603
        hparams["requirements"] = str(reqs).split("\\n")
    except subprocess.CalledProcessError:
        # not mandatory to save requirements; allows segmenter plugin devs to use PDM
        pass

    # send hparams to all loggers
    for logger in trainer.loggers:
        if isinstance(logger, MLFlowLogger):
            logger.log_hyperparams(hparams, mode=cfg.task_name)
        else:
            logger.log_hyperparams(hparams)
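

# Hedged usage sketch: the `object_dict` contract assumed above. `cfg`, `model`,
# and `trainer` would come from a live training task, so the call is left
# commented; the variable names are illustrative.
#
# object_dict = {
#     "cfg": cfg,          # full Hydra DictConfig (`model`, `data`, `trainer`, ...)
#     "model": model,      # LightningModule whose parameter counts are logged
#     "trainer": trainer,  # Trainer whose loggers receive the hyperparameters
# }
# log_hyperparameters(object_dict)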


def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule."""
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
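

# Hedged usage sketch: typical retrieval of an optimized metric after training,
# e.g. for a hyperparameter sweep. `trainer.callback_metrics` holds tensors,
# which is why `.item()` is used above; the metric name is illustrative.
#
# metric_dict = trainer.callback_metrics
# best_metric = get_metric_value(metric_dict=metric_dict, metric_name="val/loss")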


def close_loggers() -> None:
    """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
    log.info("Closing loggers...")

    if find_spec("wandb"):  # check if wandb is installed
        import wandb

        if wandb.run:
            log.info("Closing wandb!")
            wandb.finish()