"""Visualization tools."""

from typing import List, Optional, Union
import logging
import os

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

logger = logging.getLogger(__name__)
# Matplotlib 3.6 deprecated the bundled "seaborn" style name in favor of
# "seaborn-v0_8"; prefer the new name when this matplotlib provides it so the
# module still imports on releases that no longer accept the old name.
_SEABORN_STYLE = (
    "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
)
plt.style.use(_SEABORN_STYLE)
# Color cycle of the active style; _plot_df hands these out per (model, loss).
COLORS = matplotlib.rcParams["axes.prop_cycle"].by_key()["color"]

def _plot_df(df, ax, model_label, colors, **kwargs):
    """Plot each dataframe column as a line on the given axes.

    Parameters
    ----------
    df
        Dataframe whose index supplies the x-values; every column is drawn as
        a separate line.
    ax
        Axes to draw on (``ax.plot`` is called once per column).
    model_label
        Optional legend prefix; labels become ``"<model_label>:<column>"``.
        It is also part of the color key, so different models get different
        colors.
    colors
        Mutable dict shared across calls and updated in place. It maps
        ``(model_label, stem)`` to a color, where ``stem`` is the column name
        with its last ``_``-separated token removed. The special ``"idx"``
        entry holds the next position in ``COLORS`` to hand out.
    **kwargs
        Forwarded to ``ax.plot`` (e.g. ``linestyle``).
    """
    for col in df.columns:
        prefix = f"{model_label}:" if model_label is not None else ""
        label = prefix + f"{col}"
        # Strip the trailing suffix (e.g. "_train" / "_val") so the training
        # and validation curves of the same loss share one color.
        key = model_label, "_".join(col.split("_")[:-1])
        if key not in colors:
            colors[key] = COLORS[colors["idx"]]
            # Wrap around once every color in the cycle has been used.
            colors["idx"] = (colors["idx"] + 1) % len(COLORS)
        ax.plot(
            df.index.to_numpy(),
            df[col].to_numpy(),
            color=colors[key],
            label=label,
            **kwargs,
        )

def _plot_rolling(df, suffix, window, ax, model_label, colors, **kwargs):
    """Plot rolling means of the ``df`` columns whose names end with ``suffix``."""
    cols = [col for col in df.columns if col.lower().endswith(suffix)]
    # Drop columns that are entirely NaN, then any row still holding a NaN.
    df_sel = df.loc[:, cols].dropna(axis=1, thresh=1).dropna()
    _plot_df(df_sel.rolling(window=window).mean(), ax, model_label, colors, **kwargs)


def plot_loss(
    paths_model: Union[List[str], str],
    path_save: Optional[str] = None,
    train: bool = True,
    val: bool = True,
    title: Optional[str] = None,
    ymin: Optional[float] = None,
    ymax: Optional[float] = None,
) -> None:
    """Plots model loss curve(s).

    For each model directory, ``losses.csv`` is read with ``num_iter`` as the
    index. Columns ending in ``_train`` are drawn as solid lines (rolling mean
    over 128 rows). Columns ending in ``_val`` are drawn as dashed lines
    (rolling mean over 32 rows). Both suffixes are matched case-insensitively.

    Parameters
    ----------
    paths_model
        List of paths to model directories specified as a list or as a string
        of paths separated by whitespace.
    path_save
        If not None, specifies where to save figure and figure will not be
        displayed.
    train
        Set to plot training curve.
    val
        Set to plot validation curve.
    title
        Plot title.
    ymin
        Y-axis minimum value.
    ymax
        Y-axis maximum value.

    """
    if isinstance(paths_model, str):
        # split() (not split(" ")) so repeated/trailing spaces do not produce
        # empty paths that would silently resolve to "./losses.csv".
        paths_model = paths_model.split()
    if path_save is not None:
        # Non-interactive backend: the figure is only written to disk.
        plt.switch_backend("Agg")
    window_train = 128
    window_val = 32
    colors = {"idx": 0}  # maps model-content to colors; idx is COLORS index
    fig, ax = plt.subplots()
    for path_model in paths_model:
        name_model = os.path.basename(os.path.normpath(path_model))
        # Prefix legend entries with the model name only when comparing models.
        model_label = None if len(paths_model) == 1 else name_model
        path_loss = os.path.join(path_model, "losses.csv")
        df = pd.read_csv(path_loss, index_col="num_iter")
        if train:
            _plot_rolling(
                df, "_train", window_train, ax, model_label, colors, linestyle="-"
            )
        if val:
            _plot_rolling(
                df, "_val", window_val, ax, model_label, colors, linestyle="--"
            )
    if title is not None:
        ax.set_title(title)
    ax.set_ylim([ymin, ymax])
    ax.set_xlabel("Training iterations")
    ax.set_ylabel("Rolling mean squared error")
    ax.legend()
    if path_save is not None:
        fig.savefig(path_save, bbox_inches="tight")
        logger.info("Saved: %s", path_save)
        # The figure is never displayed when saving; free it so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
        return
    plt.show()
def plot_metric(
    path_csv: str,
    metric: str,
    path_save: Optional[str] = None,
    title: Optional[str] = None,
    ymin: Optional[float] = None,
    ymax: Optional[float] = None,
    ylabel: str = "Pearson correlation coefficient (r)",
) -> None:
    """Plots box-plot of model performance according to some metric.

    Every CSV column whose name contains ``metric`` becomes one box. The box
    label is the part of the column name that follows ``metric``.

    Parameters
    ----------
    path_csv
        Path to csv where each row is a dataset item.
    metric
        Name of metric. Should be within one or more CSV column names.
    path_save
        If not None, specifies where to save figure and figure will not be
        displayed.
    title
        Plot title.
    ymin
        Y-axis minimum value.
    ymax
        Y-axis maximum value.
    ylabel
        Y-axis label. Defaults to the Pearson-r label used previously; pass a
        different label when ``metric`` is not a correlation.

    """
    if path_save is not None:
        # Non-interactive backend: the figure is only written to disk.
        plt.switch_backend("Agg")
    df = pd.read_csv(path_csv)
    cols = [c for c in df.columns if metric in c]
    cols_rename = {c: c.split(metric)[-1] for c in cols}
    df = df.loc[:, cols].rename(columns=cols_rename)
    fig, ax = plt.subplots()
    df.boxplot(ax=ax)
    if title is not None:
        ax.set_title(title)
    ax.set_ylim([ymin, ymax])
    ax.set_ylabel(ylabel)
    if path_save is not None:
        fig.savefig(path_save, bbox_inches="tight")
        logger.info("Saved: %s", path_save)
        # The figure is never displayed when saving; free it so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
        return
    plt.show()