import math
from typing import Optional, Sequence
import torch
import torch.nn as nn
import tqdm
from bioio.writers import OmeTiffWriter
from monai.inferers import Inferer
from monai.networks.schedulers import NoiseSchedules
from monai.networks.schedulers.ddim import Scheduler
from monai.utils import convert_to_tensor
from torchmetrics import MeanMetric
from cyto_dl.models.base_model import BaseModel
from cyto_dl.models.im2im.utils import detach
@NoiseSchedules.add_def("inverse_cosine", "Inverse cosine beta schedule")
def inverted_cosine_beta_schedule(num_train_timesteps, s=0.008):
"""
inverted cosine schedule
as proposed in https://arxiv.org/pdf/2311.17901.pdf
"""
steps = num_train_timesteps + 1
t = torch.linspace(0, num_train_timesteps, steps) / num_train_timesteps
alphas_cumprod = (2 * (1 + s) / math.pi) * torch.arccos(torch.sqrt(t)) - s
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
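# Illustrative example (added; not part of the original module). Once registered above, the
# schedule can also be selected by name via a MONAI scheduler's `schedule` argument, which is
# looked up in the NoiseSchedules registry (exact behaviour may vary by MONAI version):
#   >>> betas = inverted_cosine_beta_schedule(num_train_timesteps=1000)
#   >>> betas.shape
#   torch.Size([1000])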
class DiffusionAutoEncoder(BaseModel):
"""
[DiffusionAutoencoder](https://arxiv.org/abs/2111.15640) for representation learning. Code is based on the [MONAI generative tutorial](https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb)
"""
def __init__(
self,
*,
autoencoder: nn.Module,
spatial_inferer: Inferer,
image_shape: Sequence[int],
condition_key: str,
noise_scheduler: Scheduler,
diffusion_inferer: Inferer,
loss: nn.Module = nn.MSELoss(),
semantic_encoder: nn.Module = None,
diffusion_key: Optional[str] = None,
n_inference_steps: int = 50,
save_dir="./",
save_images_every_n_epochs: int = 1,
n_noise_samples: Optional[int] = 1,
train_encoder: bool = True,
gamma: float = -1.0,
**base_kwargs,
):
"""
Parameters
----------
autoencoder: nn.Module
model network to denoise the diffusion image (conditioned on the latent generated by the semantic encoder)
spatial_inferer: Inferer
Inferer to use for splitting large images into patches during inference
image_shape: Sequence[int]
C[Z]YX shape of the input images
condition_key: str
key to access condition images in batch
noise_scheduler: Scheduler
beta noise scheduler. The 'inverse_cosine' schedule registered above (from the SODA paper) can be passed as the scheduler's `schedule` argument; see the MONAI docs for other options
diffusion_inferer: Inferer
Inferer to use for diffusion sampling
loss: nn.Module
loss function to use for training. Must be constructed with reduction='none' when loss weighting is used (gamma > 0)
semantic_encoder: nn.Module
model network to encode the condition image
diffusion_key: Optional[str]
key to access diffusion images in batch. If None, defaults to condition_key
n_inference_steps: int
number of noise steps used during inference. Must be less than the number of train steps used in your noise scheduler, and can be much fewer due to DDIM sampling
save_dir="./"
directory to save images during training and validation
save_images_every_n_epochs: int
Image saving frequency
n_noise_samples: Optional[int]
Number of noise samples to average for latent walk
train_encoder: bool
Whether to train the semantic encoder
gamma: float
Minimum SNR for loss weighting. If negative, no weighting is applied
**base_kwargs:
Additional arguments passed to BaseModel
"""
_DEFAULT_METRICS = {
"train/loss": MeanMetric(),
"val/loss": MeanMetric(),
"test/loss": MeanMetric(),
}
metrics = base_kwargs.pop("metrics", _DEFAULT_METRICS)
super().__init__(metrics=metrics, **base_kwargs)
self.diffusion_key = diffusion_key or condition_key
self.autoencoder = autoencoder
self.semantic_encoder = semantic_encoder
if not train_encoder:
for param in self.semantic_encoder.parameters():
param.requires_grad = False
self.scheduler = noise_scheduler
self.weights = self.scheduler.alphas_cumprod
self.inferer = diffusion_inferer(self.scheduler)
self.spatial_inferer = spatial_inferer
if gamma > 0 and (not hasattr(loss, "reduction") or loss.reduction != "none"):
raise ValueError("Loss must have reduction='none' if using loss weighting (gamma > 0)")
self.loss = loss
def _get_loss_weight(self, timesteps):
"""
Min-SNR weighting strategy from https://arxiv.org/pdf/2303.09556
"""
if self.hparams.gamma < 0:
return None
self.weights = self.weights.to(timesteps.device)
alpha_prod_t = self.weights[timesteps]
beta_prod_t = 1 - alpha_prod_t
alpha = alpha_prod_t**0.5
sigma = beta_prod_t**0.5
snr = (alpha / sigma) ** 2
min_snr = torch.clip(snr, max=self.hparams.gamma)
weight = (min_snr / snr).view(-1, 1, 1, 1)
return weight
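# Clarifying note (added): with alpha_bar_t = alphas_cumprod[t], SNR(t) = alpha_bar_t / (1 - alpha_bar_t),
# and the per-sample weight is min(SNR(t), gamma) / SNR(t). Low-noise (high-SNR) timesteps are
# down-weighted so easy denoising steps do not dominate the loss. For example, with gamma = 5:
#   >>> snr = torch.tensor([0.5, 5.0, 50.0])
#   >>> torch.clip(snr, max=5.0) / snr
#   tensor([1.0000, 1.0000, 0.1000])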
def forward(self, x_cond, x_diff):
noise = torch.randn_like(x_diff, device=x_diff.device)
timesteps = torch.randint(
0,
self.inferer.scheduler.num_train_timesteps,
(x_diff.shape[0],),
device=x_diff.device,
dtype=torch.long,
)
loss_weight = self._get_loss_weight(timesteps)
# latent is B x C x 1
latent = self.semantic_encoder(x_cond).unsqueeze(2)
noise_pred = self.inferer(
inputs=x_diff,
diffusion_model=self.autoencoder,
noise=noise,
timesteps=timesteps,
condition=latent,
)
return noise, noise_pred, latent, loss_weight
def _generate_image(self, noise, cond):
self.scheduler.set_timesteps(num_inference_steps=self.hparams.n_inference_steps)
with torch.no_grad():
sample = self.inferer.sample(
input_noise=noise,
diffusion_model=self.autoencoder,
scheduler=self.scheduler,
conditioning=cond,
verbose=False,
)
return sample
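# Note (added): set_timesteps(n_inference_steps) makes the scheduler denoise over
# n_inference_steps evenly spaced timesteps rather than all num_train_timesteps, which is why
# DDIM-style sampling with e.g. 50 steps is far cheaper than replaying a 1000-step schedule.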
def save_example(self, stage, cond_img, diff_img):
"""Save the sequence of denoising steps."""
with torch.no_grad():
cond = self.semantic_encoder(cond_img).unsqueeze(2)
noise = torch.randn_like(diff_img, device=self.device)
sample = self._generate_image(noise, cond)
for img, name in zip([cond_img, diff_img, sample], ["cond", "diff", "recon"]):
OmeTiffWriter.save(
uri=f"{self.hparams.save_dir}/{self.trainer.current_epoch}_{stage}_{name}.tiff",
data=detach(img).astype(float),
)
def model_step(self, stage, batch, batch_idx):
batch = convert_to_tensor(batch)
cond_img = batch[self.hparams.condition_key]
diff_img = batch[self.diffusion_key]
noise, noise_pred, latent, loss_weight = self.forward(cond_img, diff_img)
if (
(self.trainer.current_epoch + 1) % self.hparams.save_images_every_n_epochs
) == batch_idx == 0 and stage == "val":
self.save_example(stage, cond_img[:1], diff_img[:1])
diffusion_loss = self.loss(noise, noise_pred)
if loss_weight is not None:
diffusion_loss = torch.mean(diffusion_loss * loss_weight)
return {"loss": diffusion_loss}, latent, None
def validation_step(self, batch, batch_idx):
loss, preds, targets = self.model_step("val", batch, batch_idx)
self.compute_metrics(loss, preds, targets, "val")
return loss, preds
def generate_from_latent(
self,
cond: torch.Tensor,
save_name: str = "generated_image",
n_noise_samples: Optional[int] = None,
average: bool = True,
save: bool = True,
batch_size: int = 3,
):
"""Generate images from latent features. If average is True, average over n_noise_samples,
otherwise make a composite.
Parameters
----------
cond: torch.Tensor
latent features to condition the diffusion model
save_name: str
name to save the generated image
n_noise_samples: Optional[int]
number of noise samples to generate. If None, defaults to the model's n_noise_samples
average: bool
Whether to average the generated images. If False, composite the images.
save: bool
Whether to save the generated image. If False, return the generated image.
batch_size: int
batch size for generating images
"""
if batch_size <= 0:
raise ValueError("Batch size must be at least 1")
batch_indices = [(i, i + batch_size) for i in range(0, cond.shape[0], batch_size)]
n_noise_samples = n_noise_samples or self.hparams.n_noise_samples
with torch.no_grad():
recon = None
for _ in tqdm.tqdm(range(n_noise_samples), desc="Sampling"):
# keep noise constant across walk for consistency
noise = torch.stack(
[torch.randn(self.hparams.image_shape, device=self.device)] * cond.shape[0]
)
sample = torch.cat(
[
self._generate_image(
noise[start:stop], cond[start:stop].unsqueeze(2)
).squeeze(1)
for start, stop in batch_indices
],
0,
)
sample = sample if average else [sample]
recon = sample if recon is None else recon + sample
if average:
recon /= n_noise_samples
else:
recon = torch.cat(recon, -1)
recon = detach(recon).astype(float)
if save:
OmeTiffWriter.save(uri=f"{self.hparams.save_dir}/{save_name}.tiff", data=recon)
return recon
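# Illustrative usage sketch (added; `model`, `img_a`, `img_b` are hypothetical names, and the
# latents are assumed to broadcast to shape (n_steps, latent_dim)):
#   >>> z_a, _ = model.encode_image(img_a)
#   >>> z_b, _ = model.encode_image(img_b)
#   >>> steps = torch.linspace(0, 1, 5, device=model.device).view(-1, 1)
#   >>> walk = z_a * (1 - steps) + z_b * steps  # linear latent walk from A to B
#   >>> imgs = model.generate_from_latent(walk, save=False, average=True)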
def encode_image(self, x):
with torch.no_grad():
z, loc = self.spatial_inferer(x, self.semantic_encoder)
return z, loc
def predict_step(self, batch, batch_idx):
meta = batch[self.hparams.condition_key].meta
batch = convert_to_tensor(batch)
z, loc = self.encode_image(batch[self.hparams.condition_key])
meta.update(loc)
return detach(z), meta
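# Minimal construction sketch (added; assumes MONAI >= 1.3 component names and signatures, which
# may differ across versions; `my_encoder` and `my_spatial_inferer` are hypothetical objects):
#   >>> from monai.inferers import DiffusionInferer
#   >>> from monai.networks.nets import DiffusionModelUNet
#   >>> from monai.networks.schedulers import DDIMScheduler
#   >>> model = DiffusionAutoEncoder(
#   ...     autoencoder=DiffusionModelUNet(
#   ...         spatial_dims=2, in_channels=1, out_channels=1,
#   ...         channels=(64, 128, 256), attention_levels=(False, True, True),
#   ...         num_res_blocks=1, num_head_channels=64,
#   ...         with_conditioning=True, cross_attention_dim=256,
#   ...     ),
#   ...     semantic_encoder=my_encoder,              # nn.Module mapping image -> (B, 256) latent
#   ...     spatial_inferer=my_spatial_inferer,       # Inferer returning (features, locations)
#   ...     diffusion_inferer=DiffusionInferer,       # instantiated with the scheduler in __init__
#   ...     noise_scheduler=DDIMScheduler(num_train_timesteps=1000, schedule="inverse_cosine"),
#   ...     image_shape=(1, 64, 64),
#   ...     condition_key="raw",
#   ... )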