Source code for cyto_dl.callbacks.model_utils

import logging
import os
import tempfile
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import torch
from lightning import Callback, LightningModule, Trainer

log = logging.getLogger(__name__)


def save_predictions_classifier(preds, output_dir):
    """
    TODO: make this better? maybe use vol predictor code?
    TODO: drop unnecessary index
    """
    records = []
    for pred in preds:
        record = {}
        for col in ("id", "y", "yhat"):
            record[col] = pred[col].squeeze().numpy()
        record["loss"] = [pred["loss"].item()] * len(pred["id"])
        records.append(pd.DataFrame(record))

    pd.concat(records).reset_index().drop(columns="index").to_csv(
        Path(output_dir) / "model_predictions.csv", index_label=False
    )
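
A minimal usage sketch (not part of the original module), assuming each prediction is a dict of tensors with the keys consumed above; the values are illustrative only.

    preds = [
        {
            "id": torch.tensor([0, 1, 2]),
            "y": torch.tensor([1, 0, 1]),
            "yhat": torch.tensor([1, 0, 0]),
            "loss": torch.tensor(0.42),
        }
    ]
    save_predictions_classifier(preds, output_dir=".")
    # writes ./model_predictions.csv with one row per id and the batch loss repeated
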
class GetEmbeddings(Callback):
    """Callback that extracts latent embeddings for the train/val/test splits after
    testing and logs them to MLflow as a CSV artifact."""

    def __init__(
        self,
        x_label: str,
        id_label: Optional[str] = None,
    ):
        """
        Args:
            x_label: x_label from datamodule
            id_label: id_label from datamodule
        """
        super().__init__()
        self.x_label = x_label
        self.id_label = id_label
        self.cutoff_kld_per_dim = 0
    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        with torch.no_grad():
            embeddings = get_all_embeddings(
                trainer.datamodule.train_dataloader(),
                trainer.datamodule.val_dataloader(),
                trainer.datamodule.test_dataloader(),
                pl_module,
                self.x_label,
                self.id_label,
            )

        with tempfile.TemporaryDirectory() as tmp_dir:
            dest_path = os.path.join(tmp_dir, "embeddings.csv")
            embeddings.to_csv(dest_path)

            mlflow.log_artifact(local_path=dest_path, artifact_path="dataframes")
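
A hedged sketch of wiring the callback into a Lightning run. `model` and `datamodule` stand in for a configured cyto-dl LightningModule (exposing `latent_dim`) and datamodule; the `x_label`/`id_label` values are assumptions.

    from lightning import Trainer

    trainer = Trainer(callbacks=[GetEmbeddings(x_label="image", id_label="cell_id")])
    trainer.test(model, datamodule=datamodule)
    # on_test_epoch_end then writes embeddings.csv and logs it to MLflow
    # under the "dataframes" artifact path
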
def get_all_embeddings(
    train_dataloader,
    val_dataloader,
    test_dataloader,
    pl_module: LightningModule,
    x_label: str,
    id_label: Optional[str] = None,
):
    all_embeddings = []
    cell_ids = []
    split = []

    zip_iter = zip(
        ["train", "val", "test"],
        [train_dataloader, val_dataloader, test_dataloader],
    )

    with torch.no_grad():
        for split_name, dataloader in zip_iter:
            log.info(f"Getting embeddings for split: {split_name}")

            _bs = dataloader.batch_size
            _len = len(dataloader) * dataloader.batch_size

            # preallocate output arrays; the tail is trimmed below if the
            # last batch is smaller than the batch size
            _embeddings = np.zeros((_len, pl_module.latent_dim))
            _split = np.empty(_len, dtype=object)
            _ids = None

            id_label = pl_module.hparams.get("id_label", None) if id_label is None else id_label

            for index, batch in enumerate(dataloader):
                if _ids is None and id_label is not None and id_label in batch:
                    _ids = np.empty(_len, dtype=batch[id_label].cpu().numpy().dtype)

                for key in batch.keys():
                    if not isinstance(batch[key], list):
                        batch[key] = batch[key].to(pl_module.device)

                z_parts_params, z_composed = pl_module(batch, decode=False, compute_loss=False)

                mu_vars = z_parts_params[x_label]
                if mu_vars.shape[1] != pl_module.latent_dim:
                    # encoder output concatenates (mu, var); keep only the means
                    mus = mu_vars[:, : mu_vars.shape[1] // 2]
                else:
                    mus = mu_vars

                start = _bs * index
                end = start + len(mus)
                _embeddings[start:end] = mus.cpu().numpy()
                if _ids is not None:
                    _ids[start:end] = batch[id_label].detach().cpu().squeeze()
                _split[start:end] = [split_name] * len(mus)

            # `mus` still holds the final batch here; if it was smaller than the
            # batch size, discard the unused preallocated tail
            diff = _bs - len(mus)
            if diff > 0:
                _embeddings = _embeddings[:-diff]
                if _ids is not None:
                    _ids = _ids[:-diff]
                _split = _split[:-diff]

            all_embeddings.append(_embeddings)
            cell_ids.append(_ids)
            split.append(_split)

    all_embeddings = np.vstack(all_embeddings)
    cell_ids = np.hstack(cell_ids) if cell_ids[0] is not None else None
    split = np.hstack(split)

    df = pd.DataFrame(
        all_embeddings, columns=[f"mu_{i}" for i in range(all_embeddings.shape[1])]
    )
    df["split"] = split
    if cell_ids is not None:
        df["CellId"] = cell_ids

    return df
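
The same helper can also be called directly, e.g. for interactive inspection. A sketch assuming `dm` is a datamodule exposing the three dataloaders and `model` is a LightningModule with a `latent_dim` attribute; the label values are assumptions.

    df = get_all_embeddings(
        dm.train_dataloader(),
        dm.val_dataloader(),
        dm.test_dataloader(),
        model,
        x_label="image",    # assumed key into z_parts_params
        id_label="cell_id", # assumed per-cell id field present in each batch
    )
    df.to_csv("embeddings.csv", index=False)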