cyto_dl.callbacks.model_utils module

class cyto_dl.callbacks.model_utils.GetEmbeddings(x_label: str, id_label: str | None = None)

Bases: Callback

Args:

x_label: the x_label used by the datamodule.
id_label: the id_label used by the datamodule (optional, defaults to None).

on_test_epoch_end(trainer: Trainer, pl_module: LightningModule)
cyto_dl.callbacks.model_utils.get_all_embeddings(train_dataloader, val_dataloader, test_dataloader, pl_module: LightningModule, x_label: str, id_label: None)
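Judging from its signature, get_all_embeddings applies pl_module to the train, val, and test dataloaders and keys the resulting embeddings by x_label and id_label. The sketch below shows that pattern in plain Python. It is not the cyto_dl implementation, and it rests on several assumptions. A plain callable stands in for the LightningModule. The loaders are passed as a split-keyed dict. Batches are dicts, and the batch keys "raw" and "cell_id" are placeholders. The helpers `collect_embeddings` and `toy_embed` and the output row format are also hypothetical.

```python
from typing import Callable, Dict, Iterable, List, Optional, Sequence

# Hypothetical stand-in for pl_module: any callable that maps a batch of
# inputs to one embedding vector per sample.
EmbedFn = Callable[[Sequence], List[List[float]]]


def collect_embeddings(
    loaders: Dict[str, Iterable[dict]],
    embed_fn: EmbedFn,
    x_label: str,
    id_label: Optional[str] = None,
) -> List[dict]:
    """Hypothetical sketch: embed every batch of every split, one row per sample."""
    rows = []
    for split, loader in loaders.items():
        for batch in loader:
            embeddings = embed_fn(batch[x_label])
            # Use None as the sample id when no id_label is configured.
            ids = batch[id_label] if id_label is not None else [None] * len(embeddings)
            for sample_id, emb in zip(ids, embeddings):
                rows.append({"split": split, "id": sample_id, "embedding": list(emb)})
    return rows


# Toy "model" and made-up batches, for illustration only.
def toy_embed(xs):
    return [[x, x * 2] for x in xs]


loaders = {
    "train": [{"raw": [1, 2], "cell_id": ["a", "b"]}],
    "val": [{"raw": [3], "cell_id": ["c"]}],
    "test": [],
}
rows = collect_embeddings(loaders, toy_embed, x_label="raw", id_label="cell_id")
```

Each output row records the split that the sample came from. This keeps train, val, and test embeddings distinguishable after they are merged into a single table.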
cyto_dl.callbacks.model_utils.save_predictions_classifier(preds, output_dir)

TODO: make this better? Maybe use the vol predictor code?
TODO: drop unnecessary index.
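The "drop unnecessary index" TODO concerns writing predictions without an extra index column. A minimal standard-library sketch of that is shown below. With pandas, `DataFrame.to_csv(path, index=False)` has the same effect. This is not the documented function. `save_predictions_csv`, the default filename `predictions.csv`, and the assumption that `preds` is a non-empty list of dicts sharing the same keys are all hypothetical. The real function's input format and output file are not specified here.

```python
import csv
import os
import tempfile
from typing import Dict, List


def save_predictions_csv(
    preds: List[Dict[str, object]], output_dir: str, filename: str = "predictions.csv"
) -> str:
    """Hypothetical sketch: write one CSV row per prediction, with no index column."""
    if not preds:
        raise ValueError("preds is empty")
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, filename)
    # Column order follows the keys of the first prediction (assumed shared by all).
    fieldnames = list(preds[0].keys())
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(preds)
    return path


# Example with made-up predictions, written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    out_path = save_predictions_csv([{"id": "a", "pred": 1}, {"id": "b", "pred": 0}], tmp)
    with open(out_path, newline="") as f:
        written = list(csv.reader(f))
```

Only the keys of each prediction become columns. Row positions are therefore never written out as a separate index.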