Source code for cyto_dl.callbacks.tabular_saver

import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from lightning.pytorch.callbacks import Callback

# Report each distinct UserWarning (e.g. a missing metadata key) only once.
# NOTE(review): this modifies the process-wide warnings filter at import time,
# so it affects every UserWarning in the process, not just this module's.
warnings.simplefilter(action="once", category=UserWarning)


class SaveTabularData(Callback):
    """Callback to save tabular data to disk as a .csv or .parquet after prediction."""

    def __init__(
        self,
        save_dir,
        meta_keys=None,
        as_parquet: bool = True,
        save_suffix: str = None,
        col_prefix: str = "feat",
    ):
        """
        Parameters
        ----------
        save_dir: str
            directory to save the tabular data; created (with parents) if missing
        meta_keys: list, optional
            list of keys in the metadata to include as columns in the saved data.
            Defaults to no metadata columns.
        as_parquet: bool
            Saves data as parquet if True, otherwise saves as csv
        save_suffix: str, optional
            suffix to add to the saved file name, joined to the stage name with "_"
        col_prefix: str
            prefix to add to the column names of the saved data
        """
        super().__init__()
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        # Copy into a fresh list: avoids a shared mutable default and decouples
        # the callback from later mutation of the caller's list.
        self.meta_keys = [] if meta_keys is None else list(meta_keys)
        self.as_parquet = as_parquet
        self.save_suffix = save_suffix
        self.col_prefix = col_prefix

    def pred_to_df(self, pred):
        """Convert a 2D batch of predictions into a DataFrame, one column per feature.

        Parameters
        ----------
        pred: torch.Tensor or array-like
            2D predictions; rows become DataFrame rows and columns are named
            ``f"{col_prefix}_{i}"``.

        Returns
        -------
        pd.DataFrame
        """
        if isinstance(pred, torch.Tensor):
            # pandas converts tensors via __array__, which fails for CUDA tensors,
            # tensors requiring grad, and bfloat16 (no numpy equivalent).
            pred = pred.detach().cpu()
            if pred.dtype == torch.bfloat16:
                pred = pred.float()
            pred = pred.numpy()
        return pd.DataFrame(pred, columns=[f"{self.col_prefix}_{i}" for i in range(pred.shape[1])])

    def _parse_meta(self, meta):
        """Return a copy of ``meta`` with values simplified for use as DataFrame columns.

        Tensors become numpy arrays (moved to CPU first), and single-element
        tensors/arrays/lists are unwrapped into single values. The input mapping is
        left unmodified.
        """
        parsed = {}
        for k, v in meta.items():
            if isinstance(v, torch.Tensor):
                # Batch metadata may have been moved to the accelerator along with
                # the batch; .numpy() only works on detached CPU tensors.
                v = v.item() if v.numel() == 1 else v.detach().cpu().numpy()
            elif isinstance(v, np.ndarray) and v.ndim == 0:
                # 0-d arrays have no len(); unwrap them to a scalar directly.
                v = v.item()
            elif isinstance(v, (list, np.ndarray)) and len(v) == 1:
                v = v[0]
            parsed[k] = v
        return parsed

    def _save(self, feats, stage):
        """Concatenate per-batch DataFrames and write them into ``save_dir``.

        The file is named ``{stage}.parquet`` / ``{stage}.csv``, or
        ``{stage}_{save_suffix}.<ext>`` when a suffix was given. Warns and writes
        nothing when ``feats`` is empty.
        """
        if not feats:
            warnings.warn(f"No predictions to save for stage '{stage}'; nothing was written.")
            return
        stem = str(stage) if self.save_suffix is None else f"{stage}_{self.save_suffix}"
        # Each batch DataFrame restarts its index at 0; ignore_index avoids
        # duplicate index labels (which to_parquet would otherwise persist).
        df = pd.concat(feats, ignore_index=True)
        if self.as_parquet:
            # float16 columns are upcast to float32 before writing parquet.
            for col in df.select_dtypes(include=[np.float16]).columns:
                df[col] = df[col].astype(np.float32)
            # Column names are coerced to strings for parquet.
            df.columns = df.columns.astype(str)
            df.to_parquet(str(self.save_dir / f"{stem}.parquet"))
        else:
            df.to_csv(str(self.save_dir / f"{stem}.csv"), index=False)

    def save_feats(self, predictions, stage):
        """Turn ``(pred, meta)`` batches into one table and save it for ``stage``.

        Parameters
        ----------
        predictions: iterable of (pred, meta) tuples
            per-batch predictions (converted with ``pred_to_df``) and metadata
            mappings whose ``meta_keys`` entries are added as columns
        stage: str
            stage name, used as the output file stem
        """
        feats = []
        for pred, meta in predictions:
            meta = self._parse_meta(meta)
            batch_feats = self.pred_to_df(pred)
            for k in self.meta_keys:
                if k in meta:
                    batch_feats[k] = meta[k]
                else:
                    warnings.warn(
                        f"Metadata key {k} not found in metadata. Available keys are {meta.keys()}"
                    )
            feats.append(batch_feats)
        self._save(feats, stage)

    def on_predict_epoch_end(self, trainer, pl_module):
        """Save all predictions collected by the trainer's predict loop."""
        # Access the list of predictions from all predict_steps.
        # NOTE(review): presumably a flat list of (pred, meta) tuples (single
        # predict dataloader), and empty unless predictions are retained by the
        # trainer (return_predictions) — confirm for the Lightning version in use.
        predictions = trainer.predict_loop.predictions
        self.save_feats(predictions, "predict")