Source code for cyto_dl.models.vae.priors.gaussian

import numpy as np
import torch
import torch.nn as nn

from .abstract_prior import Prior


def compute_tc_penalty(logvar):
    return (2 * logvar).exp().mean(dim=0).sum()

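# Added note: since `sample` below uses std = exp(0.5 * logvar), logvar is log(sigma^2),
# so (2 * logvar).exp() is the squared variance (sigma^4) per latent dimension.
# compute_tc_penalty therefore returns the batch mean of the squared variances, summed
# over latent dimensions; when `tc_penalty_weight` is set, the KL losses below add this
# scalar scaled by that weight. (The "tc" name suggests a total-correlation-style
# regularizer, but that interpretation is not stated in this module.)
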
class IsotropicGaussianPrior(Prior):
    def __init__(self, *, dimensionality, clamp_logvar=8, tc_penalty_weight=None):
        self.tc_penalty_weight = tc_penalty_weight
        self.clamp_logvar = float(clamp_logvar)
        super().__init__(dimensionality)

    @property
    def param_size(self):
        return 2 * self.dimensionality

    @classmethod
    def kl_divergence(cls, mean, logvar, tc_penalty_weight=None, reduction="sum"):
        """Computes the Kullback-Leibler divergence between a diagonal gaussian and an
        isotropic (0, 1) gaussian. It also works batch-wise.

        Parameters
        ----------
        mean: torch.Tensor
            Mean of the gaussian (or batch of gaussians)

        logvar: torch.Tensor
            Log-variance of the gaussian (or batch of gaussians)

        tc_penalty_weight: float, optional
            Weight for the penalty returned by `compute_tc_penalty`. Ignored when
            `reduction` is "none".

        reduction: str
            "none" returns the elementwise KL, "mean" averages over the latent
            dimension, and "sum" (default) sums over the latent dimension.
        """
        kl = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp())

        if reduction == "none":
            loss = kl
        elif reduction == "mean":
            loss = kl.mean(dim=-1)
        else:
            loss = kl.sum(dim=-1)

        if tc_penalty_weight is not None and reduction != "none":
            tc_penalty = compute_tc_penalty(logvar)
            loss = loss + tc_penalty_weight * tc_penalty

        return loss

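    # Added note (worked form): with logvar = log(sigma^2), the expression above is the
    # standard closed-form KL between a diagonal gaussian and the standard normal,
    #   KL(N(mu, diag(sigma^2)) || N(0, I)) = 0.5 * sum_j (mu_j^2 + sigma_j^2 - log(sigma_j^2) - 1),
    # i.e. exactly -0.5 * (1 + logvar - mu^2 - exp(logvar)) reduced over the latent dimension.
    # Sanity check: mean = 0 and logvar = 0 (a standard-normal posterior) gives a KL of 0.
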
    @classmethod
    def sample(cls, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std).mul(std).add(mean)
        return eps

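    # Added note: this is the reparameterization trick, z = mean + sigma * eps with
    # eps ~ N(0, I) and sigma = exp(0.5 * logvar), so samples remain differentiable
    # with respect to `mean` and `logvar`.
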
    def forward(self, z, mode="kl", inference=False, **kwargs):
        mean_logvar = z
        mean, logvar = torch.split(mean_logvar, mean_logvar.shape[1] // 2, dim=1)

        if self.clamp_logvar:
            logvar = torch.clamp(logvar, max=abs(self.clamp_logvar))

        if mode == "kl":
            return self.kl_divergence(
                mean, logvar, tc_penalty_weight=self.tc_penalty_weight, **kwargs
            )

        if inference:
            return mean

        return self.sample(mean, logvar, **kwargs)

class DiagonalGaussianPrior(IsotropicGaussianPrior):
    def __init__(
        self,
        dimensionality=None,
        mean=None,
        logvar=None,
        learn_mean=False,
        learn_logvar=False,
        clamp_logvar=8.0,
        tc_penalty_weight=None,
    ):
        if hasattr(mean, "__len__"):
            if dimensionality is None:
                dimensionality = len(mean)
            else:
                assert dimensionality == len(mean)

        if hasattr(logvar, "__len__"):
            if dimensionality is None:
                dimensionality = len(logvar)
            else:
                assert dimensionality == len(logvar)

        assert dimensionality is not None

        super().__init__(dimensionality=dimensionality, clamp_logvar=clamp_logvar)

        if logvar is None:
            logvar = torch.zeros(self.dimensionality)
        else:
            if not hasattr(logvar, "__len__"):
                logvar = [logvar] * self.dimensionality
            logvar = torch.tensor(np.fromiter(logvar, dtype=float))

        if learn_logvar:
            logvar = nn.Parameter(logvar, requires_grad=True)
        self.logvar = logvar

        if mean is None:
            mean = torch.zeros(self.dimensionality)
        else:
            if not hasattr(mean, "__len__"):
                mean = [mean] * self.dimensionality
            mean = torch.tensor(np.fromiter(mean, dtype=float))

        if learn_mean:
            mean = nn.Parameter(mean, requires_grad=True)
        self.mean = mean

        self.tc_penalty_weight = tc_penalty_weight

    @property
    def param_size(self):
        return 2 * self.dimensionality

    @classmethod
    def kl_divergence(cls, mu1, mu2, logvar1, logvar2, tc_penalty_weight=None, reduction="sum"):
        """Computes the Kullback-Leibler divergence between two diagonal gaussians (not
        necessarily isotropic). It also works batch-wise.

        Parameters
        ----------
        mu1: torch.Tensor
            Mean of the first gaussian (or batch of first gaussians)

        mu2: torch.Tensor
            Mean of the second gaussian (or batch of second gaussians)

        logvar1: torch.Tensor
            Log-variance of the first gaussian (or batch of first gaussians)

        logvar2: torch.Tensor
            Log-variance of the second gaussian (or batch of second gaussians)

        tc_penalty_weight: float, optional
            Weight for the penalty returned by `compute_tc_penalty`, applied to `logvar1`.

        reduction: str
            "none" returns the elementwise KL; any other value sums over the latent
            dimension and averages over the batch.
        """
        mu_diff = mu2 - mu1

        kl = 0.5 * (
            (logvar2 - logvar1)
            + (logvar1 - logvar2).exp()
            + (mu_diff.pow(2) / logvar2.exp())
            - 1
        )

        if reduction == "none":
            return kl

        loss = kl.sum(dim=-1).mean()

        if tc_penalty_weight is not None:
            tc_penalty = compute_tc_penalty(logvar1)
            loss = loss + tc_penalty_weight * tc_penalty

        return loss

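    # Added note (worked form): the expression above is the closed-form KL between two
    # diagonal gaussians,
    #   KL(N(mu1, diag(sigma1^2)) || N(mu2, diag(sigma2^2)))
    #     = 0.5 * sum_j (log(sigma2_j^2 / sigma1_j^2) + sigma1_j^2 / sigma2_j^2
    #                    + (mu1_j - mu2_j)^2 / sigma2_j^2 - 1).
    # Setting mu2 = 0 and logvar2 = 0 recovers IsotropicGaussianPrior.kl_divergence
    # (up to the batch averaging applied here).
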
    def forward(self, z, mode="kl", inference=False, **kwargs):
        mean_logvar = z
        mean, logvar = torch.split(mean_logvar, mean_logvar.shape[1] // 2, dim=1)

        if self.clamp_logvar:
            logvar = torch.clamp(logvar, max=abs(self.clamp_logvar))

        if mode == "kl":
            return self.kl_divergence(
                mean,
                self.mean,
                logvar,
                self.logvar,
                tc_penalty_weight=self.tc_penalty_weight,
                **kwargs,
            )

        if inference:
            return mean

        return self.sample(mean, logvar, **kwargs)
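
if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the original module).
    # It assumes the abstract Prior base class behaves like a torch.nn.Module and stores
    # the latent dimensionality as `self.dimensionality`, and that the encoder emits a
    # (batch, 2 * dimensionality) tensor of concatenated means and log-variances.
    batch_size, latent_dim = 4, 8
    encoder_out = torch.randn(batch_size, 2 * latent_dim)

    iso_prior = IsotropicGaussianPrior(dimensionality=latent_dim)
    kl_iso = iso_prior(encoder_out, mode="kl")  # per-sample KL, shape (batch_size,)
    z = iso_prior(encoder_out, mode="sample")  # any mode other than "kl" draws reparameterized samples
    mu = iso_prior(encoder_out, mode="sample", inference=True)  # posterior means only

    diag_prior = DiagonalGaussianPrior(dimensionality=latent_dim, mean=0.0, logvar=0.0)
    kl_diag = diag_prior(encoder_out, mode="kl")  # scalar: summed over dims, averaged over batch

    print(kl_iso.shape, z.shape, mu.shape, kl_diag.item())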