Source code for cyto_dl.models.im2im.diffusion_autoencoder

import math
from typing import Optional, Sequence

import torch
import torch.nn as nn
import tqdm
from bioio.writers import OmeTiffWriter
from monai.inferers import Inferer
from monai.networks.schedulers import NoiseSchedules
from monai.networks.schedulers.ddim import Scheduler
from monai.utils import convert_to_tensor
from torchmetrics import MeanMetric

from cyto_dl.models.base_model import BaseModel
from cyto_dl.models.im2im.utils import detach


@NoiseSchedules.add_def("inverse_cosine", "Inverse cosine beta schedule")
def inverted_cosine_beta_schedule(num_train_timesteps, s=0.008):
    """Inverted cosine schedule as proposed in https://arxiv.org/pdf/2311.17901.pdf"""
    steps = num_train_timesteps + 1
    t = torch.linspace(0, num_train_timesteps, steps) / num_train_timesteps
    alphas_cumprod = (2 * (1 + s) / math.pi) * torch.arccos(torch.sqrt(t)) - s
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

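
# A minimal usage sketch for the schedule registered above, assuming MONAI schedulers
# resolve their `schedule` argument through the NoiseSchedules registry:
#
#   from monai.networks.schedulers.ddim import DDIMScheduler
#   scheduler = DDIMScheduler(num_train_timesteps=1000, schedule="inverse_cosine")
#   betas = inverted_cosine_beta_schedule(1000)  # the betas such a scheduler would build

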
class DiffusionAutoEncoder(BaseModel):
    """[DiffusionAutoencoder](https://arxiv.org/abs/2111.15640) for representation learning.

    Code is based on the [MONAI generative tutorial](https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb)
    """

    def __init__(
        self,
        *,
        autoencoder: nn.Module,
        spatial_inferer: Inferer,
        image_shape: Sequence[int],
        condition_key: str,
        noise_scheduler: Scheduler,
        diffusion_inferer: Inferer,
        loss: nn.Module = nn.MSELoss(),
        semantic_encoder: nn.Module = None,
        diffusion_key: Optional[str] = None,
        n_inference_steps: int = 50,
        save_dir: str = "./",
        save_images_every_n_epochs: int = 1,
        n_noise_samples: Optional[int] = 1,
        train_encoder: bool = True,
        gamma: float = -1.0,
        **base_kwargs,
    ):
        """
        Parameters
        ----------
        autoencoder: nn.Module
            model network to denoise the diffusion image (conditioned on the latent generated by the semantic encoder)
        spatial_inferer: Inferer
            Inferer to use for splitting large images into patches during inference
        image_shape: Sequence[int]
            C[Z]YX shape of the input images
        condition_key: str
            key to access condition images in batch
        noise_scheduler: Scheduler
            beta noise scheduler. Options include the 'inverse_cosine' schedule from the SODA paper (registered above); see the MONAI docs for other schedules
        diffusion_inferer: Inferer
            Inferer to use for diffusion sampling
        loss: nn.Module
            loss function to use for training. Should have no reduction.
        semantic_encoder: nn.Module
            model network to encode the condition image
        diffusion_key: Optional[str]
            key to access diffusion images in batch. If None, defaults to condition_key
        n_inference_steps: int
            number of noise steps used during inference. Must be less than the number of train steps used in your noise scheduler, and can be much fewer due to DDIM sampling
        save_dir: str
            directory to save images during training and validation
        save_images_every_n_epochs: int
            Image saving frequency
        n_noise_samples: Optional[int]
            Number of noise samples to average for latent walk
        train_encoder: bool
            Whether to train the semantic encoder
        gamma: float
            Minimum SNR for loss weighting. If negative, no weighting is applied
        **base_kwargs:
            Additional arguments passed to BaseModel
        """
        _DEFAULT_METRICS = {
            "train/loss": MeanMetric(),
            "val/loss": MeanMetric(),
            "test/loss": MeanMetric(),
        }
        metrics = base_kwargs.pop("metrics", _DEFAULT_METRICS)
        super().__init__(metrics=metrics, **base_kwargs)
        self.diffusion_key = diffusion_key or condition_key
        self.autoencoder = autoencoder
        self.semantic_encoder = semantic_encoder
        if not train_encoder:
            for param in self.semantic_encoder.parameters():
                param.requires_grad = False
        self.scheduler = noise_scheduler
        self.weights = self.scheduler.alphas_cumprod
        self.inferer = diffusion_inferer(self.scheduler)
        self.spatial_inferer = spatial_inferer
        if gamma > 0 and (not hasattr(loss, "reduction") or loss.reduction != "none"):
            raise ValueError("Loss must have reduction='none' if using loss weighting (gamma > 0)")
        self.loss = loss

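    # A hypothetical construction sketch; in cyto-dl these arguments are normally supplied
    # via a hydra config, and `my_denoising_unet`, `my_encoder`, and `my_patch_inferer`
    # below are placeholder objects, not the project's defaults:
    #
    #   from monai.inferers import DiffusionInferer
    #   from monai.networks.schedulers.ddim import DDIMScheduler
    #
    #   model = DiffusionAutoEncoder(
    #       autoencoder=my_denoising_unet,       # conditional denoising network
    #       spatial_inferer=my_patch_inferer,    # must return (embedding, patch locations), see encode_image
    #       image_shape=(1, 64, 64),
    #       condition_key="raw",
    #       noise_scheduler=DDIMScheduler(num_train_timesteps=1000),
    #       diffusion_inferer=DiffusionInferer,  # called later with the scheduler
    #       semantic_encoder=my_encoder,         # produces the B x C latent
    #   )
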
    def configure_optimizers(self):
        params = list(self.autoencoder.parameters())
        if self.hparams.train_encoder:
            params += list(self.semantic_encoder.parameters())
        opt = self.optimizer(params)
        sched = self.lr_scheduler(optimizer=opt)
        return [opt], [sched]

    def _get_loss_weight(self, timesteps):
        """Min-SNR weighting strategy from https://arxiv.org/pdf/2303.09556"""
        if self.hparams.gamma < 0:
            return None
        self.weights = self.weights.to(timesteps.device)
        alpha_prod_t = self.weights[timesteps]
        beta_prod_t = 1 - alpha_prod_t
        alpha = alpha_prod_t**0.5
        sigma = beta_prod_t**0.5
        snr = (alpha / sigma) ** 2
        min_snr = torch.clip(snr, max=self.hparams.gamma)
        weight = (min_snr / snr).view(-1, 1, 1, 1)
        return weight

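    # Worked example of the weighting above: with alpha_bar_t = 0.9, SNR(t) = 0.9 / 0.1 = 9;
    # for gamma = 5 the weight is min(9, 5) / 9 ≈ 0.56, down-weighting low-noise (easy)
    # timesteps, while alpha_bar_t = 0.5 gives SNR = 1 and weight 1, so noisier (harder)
    # timesteps keep their full loss.
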
    def forward(self, x_cond, x_diff):
        noise = torch.randn_like(x_diff, device=x_diff.device)
        timesteps = torch.randint(
            0,
            self.inferer.scheduler.num_train_timesteps,
            (x_diff.shape[0],),
            device=x_diff.device,
            dtype=torch.long,
        )
        loss_weight = self._get_loss_weight(timesteps)
        # latent is B x C x 1
        latent = self.semantic_encoder(x_cond).unsqueeze(2)
        noise_pred = self.inferer(
            inputs=x_diff,
            diffusion_model=self.autoencoder,
            noise=noise,
            timesteps=timesteps,
            condition=latent,
        )
        return noise, noise_pred, latent, loss_weight

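    # Note on forward(): x_cond and x_diff are batches of C[Z]YX images; the semantic
    # latent is unsqueezed to B x C x 1 so the inferer can hand it to the denoising
    # network as conditioning (e.g. cross-attention context, depending on the network),
    # and the network is trained to predict the added noise (epsilon prediction).
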
    def _generate_image(self, noise, cond):
        self.scheduler.set_timesteps(num_inference_steps=self.hparams.n_inference_steps)
        with torch.no_grad():
            sample = self.inferer.sample(
                input_noise=noise,
                diffusion_model=self.autoencoder,
                scheduler=self.scheduler,
                conditioning=cond,
                verbose=False,
            )
        return sample

    def save_example(self, stage, cond_img, diff_img):
        """Save the condition image, diffusion target, and generated reconstruction."""
        with torch.no_grad():
            cond = self.semantic_encoder(cond_img).unsqueeze(2)
            noise = torch.randn_like(diff_img, device=self.device)
            sample = self._generate_image(noise, cond)
            for img, name in zip([cond_img, diff_img, sample], ["cond", "diff", "recon"]):
                OmeTiffWriter.save(
                    uri=f"{self.hparams.save_dir}/{self.trainer.current_epoch}_{stage}_{name}.tiff",
                    data=detach(img).astype(float),
                )

    def model_step(self, stage, batch, batch_idx):
        batch = convert_to_tensor(batch)
        cond_img = batch[self.hparams.condition_key]
        diff_img = batch[self.diffusion_key]
        noise, noise_pred, latent, loss_weight = self.forward(cond_img, diff_img)
        if (
            (self.trainer.current_epoch + 1) % self.hparams.save_images_every_n_epochs
        ) == batch_idx == 0 and stage == "val":
            self.save_example(stage, cond_img[:1], diff_img[:1])
        diffusion_loss = self.loss(noise, noise_pred)
        if loss_weight is not None:
            diffusion_loss = torch.mean(diffusion_loss * loss_weight)
        return {"loss": diffusion_loss}, latent, None

    def validation_step(self, batch, batch_idx):
        loss, preds, targets = self.model_step("val", batch, batch_idx)
        self.compute_metrics(loss, preds, targets, "val")
        return loss, preds

    def generate_from_latent(
        self,
        cond: torch.Tensor,
        save_name: str = "generated_image",
        n_noise_samples: Optional[int] = None,
        average: bool = True,
        save: bool = True,
        batch_size: int = 3,
    ):
        """Generate images from latent features. If average is True, average over
        n_noise_samples, otherwise make a composite.

        Parameters
        ----------
        cond: torch.Tensor
            latent features to condition the diffusion model
        save_name: str
            name to save the generated image
        n_noise_samples: int
            number of noise samples to generate
        average: bool
            Whether to average the generated images. If False, composite the images.
        save: bool
            Whether to save the generated image. If False, return the generated image.
        batch_size: int
            batch size for generating images
        """
        if batch_size <= 0:
            raise ValueError("Batch size must be at least 1")
        batch_indices = [(i, i + batch_size) for i in range(0, cond.shape[0], batch_size)]
        n_noise_samples = n_noise_samples or self.hparams.n_noise_samples
        with torch.no_grad():
            recon = None
            for _ in tqdm.tqdm(range(n_noise_samples), desc="Sampling"):
                # keep noise constant across walk for consistency
                noise = torch.stack(
                    [torch.randn(self.hparams.image_shape, device=self.device)] * cond.shape[0]
                )
                sample = torch.cat(
                    [
                        self._generate_image(
                            noise[start:stop], cond[start:stop].unsqueeze(2)
                        ).squeeze(1)
                        for start, stop in batch_indices
                    ],
                    0,
                )
                sample = sample if average else [sample]
                recon = sample if recon is None else recon + sample
        if average:
            recon /= n_noise_samples
        else:
            recon = torch.cat(recon, -1)
        recon = detach(recon).astype(float)
        if save:
            OmeTiffWriter.save(uri=f"{self.hparams.save_dir}/{save_name}.tiff", data=recon)
        return recon

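    # A hypothetical latent-interpolation ("latent walk") sketch using the method above;
    # `model` is assumed to be a trained DiffusionAutoEncoder and `img_a` / `img_b`
    # preprocessed C[Z]YX tensors already on model.device:
    #
    #   z_a = model.semantic_encoder(img_a.unsqueeze(0))          # 1 x C latent
    #   z_b = model.semantic_encoder(img_b.unsqueeze(0))
    #   alphas = torch.linspace(0, 1, steps=5, device=model.device).view(-1, 1)
    #   walk = alphas * z_b + (1 - alphas) * z_a                  # 5 x C interpolated latents
    #   model.generate_from_latent(walk, save_name="latent_walk", average=True)
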
    def encode_image(self, x):
        with torch.no_grad():
            z, loc = self.spatial_inferer(x, self.semantic_encoder)
        return z, loc

    def predict_step(self, batch, batch_idx):
        meta = batch[self.hparams.condition_key].meta
        batch = convert_to_tensor(batch)
        z, loc = self.encode_image(batch[self.hparams.condition_key])
        meta.update(loc)
        return detach(z), meta