Source code for cyto_dl.models.classification.timepoint_classification

from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from einops import rearrange

from cyto_dl.models.classification import Classification
from cyto_dl.models.utils import find_indices


class TimepointClassification(Classification):
    """Frame-wise classifier that reduces a track to formation/breakdown events.

    ``predict_step`` labels every frame of a track (the comments below use
    0 = interphase and 1 = mitotic). It then converts the label sequence
    into one "formation" timepoint (mitotic -> interphase) and one
    "breakdown" timepoint (interphase -> mitotic), using -1 when no
    qualifying event exists.
    """

    def __init__(
        self,
        *,
        model: nn.Module,
        x_key: str,
        num_classes: int,
        y_key: str = "label",
        save_dir="./",
        save_movie: bool = True,
        save_images_every_n_epochs=10,
        compile=False,
        write_batch_predictions=False,
        **base_kwargs,
    ):
        """Initialise the parent ``Classification`` model.

        All arguments except ``save_movie`` are forwarded unchanged to
        ``Classification.__init__``, together with any extra ``base_kwargs``.

        Parameters
        ----------
        save_movie:
            Whether ``predict_step`` calls ``save_images`` for each track.
            It is NOT forwarded to the parent. ``predict_step`` reads it back
            from ``self.hparams``, so it is presumably recorded there by the
            base-class machinery (TODO confirm).
        """
        forwarded = dict(
            model=model,
            x_key=x_key,
            num_classes=num_classes,
            y_key=y_key,
            save_dir=save_dir,
            save_images_every_n_epochs=save_images_every_n_epochs,
            compile=compile,
            write_batch_predictions=write_batch_predictions,
        )
        super().__init__(**forwarded, **base_kwargs)

    def predict_step(self, batch, batch_idx):
        """Classify every frame of one track and extract its cell-cycle events.

        Parameters
        ----------
        batch:
            Mapping that provides at least:
            - ``self.hparams.x_key``: the image tensor.
            - ``"track_id"``: a one-element tensor (``.item()`` is called on it).
            - ``"timepoints"``: a sequence whose first entry is a string; its
              first and last characters are dropped and the rest is split on
              commas (presumably a bracketed list like ``"[3, 4, 5]"``).
        batch_idx:
            Index of this batch. Used only to name the optional CSV file.

        Returns
        -------
        dict
            ``track_id``: the track id as a Python scalar.
            ``formation``: first mitotic->interphase timepoint, kept only if
            it falls before the track midpoint; otherwise -1.
            ``breakdown``: last interphase->mitotic timepoint, kept only if it
            falls after the track midpoint; otherwise -1.
            ``timepoints``: the parsed integer timepoints array.
        """
        hp = self.hparams

        # Swap the first two axes so the frames of the batch end up on axis 1.
        # Presumably c == 1, so the model sees the whole track as one sample
        # and squeeze(0) drops that singleton axis (TODO confirm).
        frames = rearrange(batch[hp.x_key], "b c h w -> c b h w")
        logits = self(frames).squeeze(0)
        preds = logits.argmax(dim=1).cpu().numpy()

        if hp.write_batch_predictions:
            csv_path = Path(hp.save_dir) / f"predictions_batch={batch_idx}.csv"
            pd.DataFrame([preds]).to_csv(csv_path, index=False)

        track_id = batch["track_id"].cpu().item()
        if hp.save_movie:
            self.save_images(batch, "predict", logits, name=f"{track_id}")

        raw_timepoints = batch["timepoints"][0]
        timepoints = np.array(raw_timepoints[1:-1].split(",")).astype(int)
        track_midpoint = (timepoints[0] + timepoints[-1]) // 2

        # preds is assumed to align one-to-one with timepoints (TODO confirm);
        # the indices below index into timepoints.
        # A [0, 1] pair marks the frame before interphase -> mitotic (breakdown).
        breakdown_idx = find_indices(preds, [0, 1])
        # A [1, 0] pair marks mitotic -> interphase. Shift by one so the index
        # points at the first interphase frame after the transition.
        formation_idx = find_indices(preds, [1, 0]) + 1

        def _event_time(indices, pick, in_valid_half):
            """Return timepoints[pick(indices)], or -1 if absent or in the wrong half."""
            if indices.size == 0:
                return -1
            candidate = timepoints[pick(indices)]
            return candidate if in_valid_half(candidate) else -1

        # Several formations: use the earliest; it must lie in the first half.
        formation = _event_time(formation_idx, np.min, lambda t: t < track_midpoint)
        # Several breakdowns: use the latest; it must lie in the second half.
        breakdown = _event_time(breakdown_idx, np.max, lambda t: t > track_midpoint)

        return {
            "track_id": track_id,
            "formation": formation,
            "breakdown": breakdown,
            "timepoints": timepoints,
        }