Source code for cyto_dl.callbacks.latent_walk_diffae

from typing import Optional
from warnings import warn

import cv2
import numpy as np
import torch
from bioio.writers import OmeTiffWriter
from lightning.pytorch.callbacks import Callback
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from cyto_dl.models.im2im.utils.postprocessing import detach


class DiffAELatentWalk(Callback):
    def __init__(
        self,
        num_pcs: int = 8,
        n_steps: int = 10,
        sigma_range: Optional[int] = None,
        every_n_epoch: int = 1,
        n_noise_samples: int = 1,
        average: bool = True,
        batch_size: int = 3,
    ):
        """
        Parameters
        ----------
        num_pcs: int = 8
            Number of principal components to use for latent walk
        n_steps: int = 10
            Number of steps to traverse each PC in the latent walk
        sigma_range: Optional[int] = None
            Range to traverse each PC in the latent walk. If None, the min and max of the PC are used.
        every_n_epoch: int = 1
            Frequency to perform latent walk
        n_noise_samples: int = 1
            Number of noise samples to generate for each latent walk step
        average: bool = True
            Whether to average the generated images
        batch_size: int = 3
            Batch size for generating images to prevent GPU OOM
        """
        self.num_pcs = num_pcs
        self.n_steps = n_steps
        self.sigma_range = int(sigma_range) if sigma_range is not None else None
        self.every_n_epoch = every_n_epoch
        self.n_noise_samples = n_noise_samples
        self.average = average
        self.batch_size = batch_size
        self.pca = Pipeline([("pca", PCA(n_components=num_pcs)), ("scaler", StandardScaler())])
        self.val_feats = []

    def _write_text(self, img, text):
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        color = tuple([img.max()] * 3)
        thickness = 1
        text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
        text_x = img.shape[1] - text_size[0] - 3  # 3 pixels from the right edge
        text_y = text_size[1] + 3  # 3 pixels from the top edge
        cv2.putText(img, text, (text_x, text_y), font, font_scale, color, thickness)
        return img

    def _write_pc_vals(self, walk_img, ranges):
        """Write PC index and value on image."""
        idx = 0
        for i, range_ in enumerate(ranges):
            for val in range_:
                walk_img[idx] = self._write_text(walk_img[idx], f"PC{i+1}:{val:.1f}")
                idx += 1
        return walk_img

    def _latent_walk(self, feats, model, save_path):
        # catch if only one batch for validation
        if len(feats.shape) == 1 or feats.shape[0] < self.num_pcs:
            warn(f"Insufficient data for latent walk with {self.num_pcs} PCs. Skipping...")
            return
        pca_data = self.pca.fit_transform(feats)
        print(f"Explained variance ratio: {self.pca['pca'].explained_variance_ratio_}")
        walk = []
        ranges = []
        for pc in np.arange(self.num_pcs):
            std = pca_data[:, pc].std()
            if self.sigma_range is None:
                min = pca_data[:, pc].min() / std
                max = pca_data[:, pc].max() / std
                range_ = np.linspace(min, max, self.n_steps)
            else:
                range_ = np.arange(-self.sigma_range, self.sigma_range + 0.01)
            print(f"PC{pc} range: {range_}")
            for i in range_:
                array = np.zeros(self.num_pcs)
                array[pc] = i * std
                walk.append(array)
            ranges.append(range_)
        walk = np.stack(walk).squeeze()
        walk = self.pca.inverse_transform(walk)
        walk = torch.from_numpy(walk).float().to(model.device)
        walk_img = model.generate_from_latent(
            walk,
            n_noise_samples=self.n_noise_samples,
            average=self.average,
            save=False,
            batch_size=self.batch_size,
        )
        walk_img = self._write_pc_vals(walk_img, ranges)
        OmeTiffWriter.save(uri=save_path, data=walk_img)
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        if (trainer.current_epoch + 1) % self.every_n_epoch == 0:
            latent_feat = detach(outputs[1].squeeze(-1))
            self.val_feats.append(latent_feat)
    def on_validation_epoch_end(self, trainer, pl_module):
        if (trainer.current_epoch + 1) % self.every_n_epoch == 0:
            # aggregate all latent features for PCA
            feats = np.concatenate(self.val_feats)
            self._latent_walk(
                feats,
                trainer.model,
                f"{pl_module.hparams.save_dir}/{trainer.current_epoch+1}_latent_walk.tiff",
            )
            # clear accumulated features so the next walk uses only that epoch's validation data
            self.val_feats = []
    def on_predict_epoch_end(self, trainer, pl_module):
        feats = np.concatenate([x[0] for x in trainer.predict_loop.predictions])
        self._latent_walk(feats, trainer.model, f"{pl_module.hparams.save_dir}/latent_walk.tiff")
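For context, a minimal usage sketch (not part of the module) showing how this callback might be attached to a Lightning Trainer. The model and datamodule are left as placeholders: the callback only assumes a LightningModule that exposes hparams.save_dir and a generate_from_latent method, as used above, and it is not a complete training configuration for cyto_dl.

# Hypothetical usage sketch; model/datamodule construction is elided.
from lightning.pytorch import Trainer

from cyto_dl.callbacks.latent_walk_diffae import DiffAELatentWalk

latent_walk = DiffAELatentWalk(
    num_pcs=8,          # number of principal components to traverse
    n_steps=10,         # steps along each PC
    sigma_range=None,   # None -> use each PC's observed min/max (in std units)
    every_n_epoch=5,    # run the walk every 5th validation epoch
    n_noise_samples=1,  # noise samples per walk step
    average=True,       # average the generated images across noise samples
    batch_size=3,       # generation batch size to limit GPU memory use
)

model = ...       # a DiffAE LightningModule with hparams.save_dir and generate_from_latent
datamodule = ...  # any LightningDataModule providing validation batches

trainer = Trainer(max_epochs=50, callbacks=[latent_walk])
trainer.fit(model, datamodule=datamodule)

During validation epochs that match every_n_epoch, the callback collects the semantic latent codes returned by the validation step, fits a PCA on them, and writes the traversal montage as an OME-TIFF under the model's save_dir.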