# Source code for cyto_dl.utils.rich_utils

from pathlib import Path
from typing import Sequence

import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt

from cyto_dl.utils import pylogger

log = pylogger.get_pylogger(__name__)





@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompt the user to input tags from the command line if none are configured.

    If ``cfg.tags`` is missing or empty, the user is asked for a comma-separated
    list of tags (default ``"dev"``). The entries are whitespace-stripped, blank
    entries are dropped, and the result is stored back into ``cfg.tags``.
    Tags that are already configured are left untouched.

    Args:
        cfg: Hydra/OmegaConf config. ``cfg.tags`` is read and possibly written;
            ``cfg.paths.output_dir`` is read when ``save_to_file`` is true.
        save_to_file: If true, write ``cfg.tags`` (whether pre-existing or just
            entered) to ``<cfg.paths.output_dir>/tags.log`` via ``rich.print``.

    Raises:
        ValueError: If no tags are configured and ``HydraConfig().cfg.hydra.job``
            contains an ``id`` entry. Per the error message, this is treated as
            a multirun, in which tags must be supplied up front.
    """
    if not cfg.get("tags"):
        # An ``id`` on the Hydra job marks a multirun here; refuse to prompt
        # interactively in that case and ask for tags up front instead.
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip first, then drop blanks: filtering the unstripped pieces let
        # inputs such as "a, ,b" or "a, " through as empty-string tags.
        tags = [tag for tag in (piece.strip() for piece in raw_tags.split(",")) if tag]

        # ``open_dict`` allows assigning a new key on a struct-mode config.
        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        tags_file = Path(cfg.paths.output_dir, "tags.log")
        with tags_file.open("w", encoding="utf-8") as file:
            rich.print(cfg.tags, file=file)
if __name__ == "__main__":
    # Manual smoke check: compose the training config outside of a Hydra app
    # and render it with ``print_config_tree`` (expected to be defined
    # elsewhere in this module; it is not part of this block).
    from hydra import compose, initialize

    config_dir = "../../configs"
    with initialize(config_path=config_dir, version_base="1.2"):
        train_cfg = compose(
            config_name="train.yaml",
            overrides=[],
            return_hydra_config=False,
        )
        print_config_tree(train_cfg, save_to_file=False, resolve=False)