Source code for cyto_dl.loggers.mlflow

import os
import tempfile
import warnings
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, Union

import mlflow
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger as _MLFlowLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.local_artifact_repo import LocalArtifactRepository
from mlflow.utils.file_utils import local_file_uri_to_path
from omegaconf import OmegaConf

from cyto_dl import utils

log = utils.get_pylogger(__name__)


class MLFlowLogger(_MLFlowLogger):
    def __init__(
        self,
        experiment_name: str = "lightning_logs",
        run_name: Optional[str] = None,
        tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"),
        tags: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = "./mlruns",
        prefix: str = "",
        artifact_location: Optional[str] = None,
        run_id: Optional[str] = None,
        fault_tolerant=True,
    ):
        _MLFlowLogger.__init__(
            self,
            experiment_name=experiment_name,
            run_name=run_name,
            tracking_uri=tracking_uri,
            tags=tags,
            save_dir=save_dir,
            prefix=prefix,
            artifact_location=artifact_location,
            run_id=run_id,
        )
        self.fault_tolerant = fault_tolerant

        if tracking_uri is not None:
            mlflow.set_tracking_uri(tracking_uri)
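
    # With `fault_tolerant=True` (the default), `log_metrics` and
    # `after_save_checkpoint` below catch any exception raised while talking
    # to the tracking server (e.g. a network outage), log a warning, and let
    # training continue; with `fault_tolerant=False` the exception propagates.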

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], mode="train") -> None:
        requirements = params.pop("requirements", [])

        with tempfile.TemporaryDirectory() as tmp_dir:
            conf_path = Path(tmp_dir) / f"{mode}.yaml"
            with conf_path.open("w") as f:
                config = OmegaConf.create(params)
                OmegaConf.save(config=config, f=f)

            reqs_path = Path(tmp_dir) / f"{mode}-requirements.txt"
            reqs_path.write_text("\n".join(requirements))

            self.experiment.log_artifact(self.run_id, local_path=conf_path, artifact_path="config")
            self.experiment.log_artifact(
                self.run_id, local_path=reqs_path, artifact_path="requirements"
            )
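
    # Artifact layout produced by a hypothetical call such as
    # `logger.log_hyperparams({"lr": 1e-3, "requirements": ["torch"]}, mode="train")`:
    #   config/train.yaml                      (OmegaConf dump of the params)
    #   requirements/train-requirements.txt    (one requirement per line)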

    @rank_zero_only
    def log_metrics(self, metrics, step):
        try:
            super().log_metrics(metrics, step)
        except Exception as e:
            if self.fault_tolerant:
                log.warning(
                    f"`MLFlowLogger.log_metrics` failed with exception {e}\n\nContinuing..."
                )
            else:
                raise

    def after_save_checkpoint(self, ckpt_callback):
        try:
            self._after_save_checkpoint(ckpt_callback)
        except Exception as e:
            if self.fault_tolerant:
                log.warning(
                    f"`MLFlowLogger.after_save_checkpoint` failed with exception {e}\n\n"
                    "Continuing..."
                )
            else:
                raise

    def _after_save_checkpoint(self, ckpt_callback: ModelCheckpoint) -> None:
        """Called after the model checkpoint callback saves a new checkpoint."""
        monitor = ckpt_callback.monitor
        if monitor is not None:
            artifact_path = f"checkpoints/{monitor}"

            # Diff the checkpoints already stored in MLflow against the
            # callback's current top-k set, so only the delta is transferred.
            existing_ckpts = {
                _.path.split("/")[-1]
                for _ in self.experiment.list_artifacts(self.run_id, path=artifact_path)
            }
            top_k_ckpts = {_.split("/")[-1] for _ in ckpt_callback.best_k_models.keys()}

            to_delete = existing_ckpts - top_k_ckpts
            to_upload = top_k_ckpts - existing_ckpts

            run = self.experiment.get_run(self.run_id)
            repository = get_artifact_repository(run.info.artifact_uri)

            for ckpt in to_delete:
                if isinstance(repository, LocalArtifactRepository):
                    _delete_local_artifact(repository, f"checkpoints/{monitor}/{ckpt}")
                elif hasattr(repository, "delete_artifacts"):
                    repository.delete_artifacts(f"checkpoints/{monitor}/{ckpt}")
                else:
                    warnings.warn(
                        "The artifact repository configured for this "
                        "MLFlow server doesn't support deleting artifacts, "
                        "so we're keeping all checkpoints."
                    )

            for ckpt in to_upload:
                self.experiment.log_artifact(
                    self.run_id,
                    local_path=os.path.join(ckpt_callback.dirpath, ckpt),
                    artifact_path=artifact_path,
                )

            # also save the current best model as "best.ckpt"
            filepath = ckpt_callback.best_model_path
            best_path = Path(filepath).with_name("best.ckpt")
            os.link(filepath, best_path)
            self.experiment.log_artifact(
                self.run_id, local_path=best_path, artifact_path=artifact_path
            )
            best_path.unlink()
        else:
            filepath = ckpt_callback.best_model_path
            artifact_path = "checkpoints"

            # mimic ModelCheckpoint's behavior: if `save_top_k == 1` only
            # keep the latest checkpoint, otherwise keep all of them
            if ckpt_callback.save_top_k == 1:
                last_path = Path(filepath).with_name("last.ckpt")
                os.link(filepath, last_path)
                self.experiment.log_artifact(
                    self.run_id, local_path=last_path, artifact_path=artifact_path
                )
                last_path.unlink()
            else:
                self.experiment.log_artifact(
                    self.run_id, local_path=filepath, artifact_path=artifact_path
                )
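
    # Resulting artifact layout for a hypothetical
    # `ModelCheckpoint(monitor="val/loss", save_top_k=2)`; checkpoint
    # filenames follow ModelCheckpoint's defaults and will differ per run:
    #   checkpoints/val/loss/epoch=3-step=400.ckpt
    #   checkpoints/val/loss/epoch=7-step=800.ckpt
    #   checkpoints/val/loss/best.ckpt    (hard-linked copy of the current best)
    # With `monitor=None` and `save_top_k=1`, only `checkpoints/last.ckpt` is kept.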


def _delete_local_artifact(repo, artifact_path):
    # Resolve the artifact's location inside the local artifact store and
    # remove the file if it exists.
    artifact_path = Path(
        local_file_uri_to_path(
            os.path.join(repo._artifact_dir, artifact_path)
            if artifact_path
            else repo._artifact_dir
        )
    )
    if artifact_path.is_file():
        artifact_path.unlink()
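

# Usage sketch: wiring the logger into a Lightning `Trainer`. The experiment
# name, tracking URI, and checkpoint settings below are hypothetical examples;
# `Trainer` and `ModelCheckpoint` are standard Lightning APIs.
if __name__ == "__main__":
    from lightning.pytorch import Trainer

    logger = MLFlowLogger(
        experiment_name="my_experiment",       # hypothetical
        tracking_uri="http://localhost:5000",  # hypothetical local server
        fault_tolerant=True,                   # warn and continue on MLflow errors
    )

    # A monitored checkpoint callback: `_after_save_checkpoint` mirrors its
    # top-k set under `checkpoints/val/loss/` in the MLflow run.
    ckpt = ModelCheckpoint(monitor="val/loss", save_top_k=2, dirpath="ckpts")

    trainer = Trainer(logger=logger, callbacks=[ckpt], max_epochs=10)
    # trainer.fit(model, datamodule)  # model and datamodule defined elsewhere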