Source code for cyto_dl.models.utils.mlflow

import logging
import tempfile

import mlflow
from hydra._internal.utils import _locate
from hydra.utils import instantiate
from omegaconf import OmegaConf

logger = logging.getLogger(__name__)


def get_config(tracking_uri, run_id, tmp_dir, mode="train"):
    """Fetch a saved config YAML from an MLflow run and return it fully resolved.

    Args:
        tracking_uri: MLflow tracking URI; applied process-wide via
            ``mlflow.set_tracking_uri`` before downloading.
        run_id: ID of the MLflow run that logged the config artifact.
        tmp_dir: Local directory the artifact is downloaded into (``dst_path``).
        mode: Name of the config file under the run's ``config/`` artifact
            folder, i.e. the artifact fetched is ``config/{mode}.yaml``.

    Returns:
        The loaded config converted to plain Python containers via
        ``OmegaConf.to_container`` with interpolations resolved.
    """
    mlflow.set_tracking_uri(tracking_uri)
    local_yaml = mlflow.artifacts.download_artifacts(
        run_id=run_id,
        artifact_path=f"config/{mode}.yaml",
        dst_path=tmp_dir,
    )
    # resolve=True so ${...} interpolations are expanded in the returned container.
    return OmegaConf.to_container(OmegaConf.load(local_yaml), resolve=True)
def load_model_from_checkpoint(tracking_uri, run_id, strict=True, path="checkpoints/last.ckpt"):
    """Rebuild a model from a checkpoint and training config logged to an MLflow run.

    The checkpoint artifact at ``path`` and the run's ``config/train.yaml`` are
    downloaded into a temporary directory. The ``model`` section of that config
    names the class to load (its ``_target_`` entry). The remaining keys are
    passed through Hydra's ``instantiate`` and forwarded as keyword arguments to
    the class's ``load_from_checkpoint``.

    Args:
        tracking_uri: MLflow tracking URI; applied via ``mlflow.set_tracking_uri``.
        run_id: ID of the MLflow run holding the checkpoint and config artifacts.
        strict: Forwarded as ``strict=`` to ``load_from_checkpoint``.
        path: Artifact path of the checkpoint within the run.

    Returns:
        The loaded model after calling ``.eval()`` on it.

    Raises:
        KeyError: If the config lacks a ``model`` section or that section has
            no ``_target_`` entry.
    """
    mlflow.set_tracking_uri(tracking_uri)
    with tempfile.TemporaryDirectory() as scratch:
        checkpoint_file = mlflow.artifacts.download_artifacts(
            run_id=run_id, artifact_path=path, dst_path=scratch
        )
        model_section = get_config(tracking_uri, run_id, scratch, mode="train")["model"]
        # Split the class path off so only constructor kwargs go to instantiate().
        target_name = model_section.pop("_target_")
        init_kwargs = instantiate(model_section)
        model_cls = _locate(target_name)
        # Load before leaving the `with` block: the checkpoint file is deleted on exit.
        model = model_cls.load_from_checkpoint(checkpoint_file, **init_kwargs, strict=strict)
        return model.eval()