Source code for cyto_dl.callbacks.csv_saver

import warnings
from pathlib import Path

import pandas as pd
from lightning.pytorch.callbacks import Callback


class CSVSaver(Callback):
    """Callback that writes the predictions gathered during predict to a CSV file.

    At the end of the predict epoch the predictions stored on the trainer's
    predict loop are converted to DataFrames, annotated with the requested
    metadata columns, concatenated, and saved as ``<save_dir>/predict.csv``.

    Args:
        save_dir: Directory the CSV is written to. It is created (including
            parents) on construction if it does not already exist.
        meta_keys: Keys that are looked up in each batch's ``meta`` object and
            added as columns to that batch's features. ``None`` (the default)
            means no metadata columns are added.
    """

    def __init__(self, save_dir, meta_keys=None):
        super().__init__()
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        # Copy into a fresh list: a shared mutable default (``meta_keys=[]``)
        # or a caller-owned list must never be aliased between instances.
        self.meta_keys = [] if meta_keys is None else list(meta_keys)

    def pred_to_csv(self, pred):
        """Convert the prediction of a single batch to a DataFrame.

        Subclasses override this to handle other prediction layouts.

        Args:
            pred: Anything accepted by the ``pandas.DataFrame`` constructor.

        Returns:
            A ``pandas.DataFrame`` built from ``pred``.
        """
        return pd.DataFrame(pred)

    def save_feats(self, predictions, stage):
        """Concatenate per-batch features and write them to ``<stage>.csv``.

        Args:
            predictions: Iterable of ``(pred, meta)`` pairs. ``pred`` is passed
                to :meth:`pred_to_csv`; ``meta`` is indexed with every key in
                ``self.meta_keys``.
            stage: Name used as the stem of the output file.

        Raises:
            KeyError: If ``meta`` is a mapping lacking one of ``self.meta_keys``.
        """
        feats = []
        for pred, meta in predictions:
            batch_feats = self.pred_to_csv(pred)
            for k in self.meta_keys:
                batch_feats[k] = meta[k]
            feats.append(batch_feats)

        if not feats:
            # ``pd.concat`` raises an opaque "No objects to concatenate"
            # ValueError on an empty list. Nothing was predicted (or, presumably,
            # predictions were not retained by the trainer -- TODO confirm), so
            # skip writing and tell the user why no file appeared.
            warnings.warn(
                f"No predictions available for stage '{stage}'; no CSV was written.",
                stacklevel=2,
            )
            return

        pd.concat(feats).to_csv(self.save_dir / f"{stage}.csv", index=False)

    def on_predict_epoch_end(self, trainer, pl_module):
        """Save all predictions collected by the trainer's predict loop.

        Args:
            trainer: The trainer; its ``predict_loop.predictions`` is read.
            pl_module: The module being run (unused).
        """
        # Access the list of predictions from all predict_steps
        predictions = trainer.predict_loop.predictions
        self.save_feats(predictions, "predict")
class JEPASaver(CSVSaver):
    """CSV saver for predictions made of three parts: source, target and predicted target."""

    def pred_to_csv(self, pred):
        """Stack the three parts of a prediction into one labelled DataFrame.

        Args:
            pred: An iterable of exactly three items, in the order
                (source, target, predicted target). Each item must be accepted
                by the ``pandas.DataFrame`` constructor.

        Returns:
            A ``pandas.DataFrame`` holding the rows of all three parts, in that
            order, with a ``feat_type`` column set to ``"source"``,
            ``"target"`` or ``"pred"`` respectively. Row indices are not reset.
        """
        # Strict unpacking: anything other than exactly three items is an error.
        source_embed, target_embed, pred_target_embed = pred
        labelled_parts = (
            ("source", source_embed),
            ("target", target_embed),
            ("pred", pred_target_embed),
        )

        frames = []
        for label, embedding in labelled_parts:
            frame = pd.DataFrame(embedding)
            frame["feat_type"] = label
            frames.append(frame)

        return pd.concat(frames)