cyto_dl.models.vae.priors.gaussian module#

class cyto_dl.models.vae.priors.gaussian.DiagonalGaussianPrior(dimensionality=None, mean=None, logvar=None, learn_mean=False, learn_logvar=False, clamp_logvar=8.0, tc_penalty_weight=None)[source]#

Bases: IsotropicGaussianPrior

forward(z, mode='kl', inference=False, **kwargs)[source]#
classmethod kl_divergence(mu1, mu2, logvar1, logvar2, tc_penalty_weight=None, reduction='sum')[source]#

Computes the Kullback-Leibler divergence between two diagonal Gaussians (not necessarily isotropic). The inputs may also be batches of distribution parameters, in which case the divergence is computed batch-wise.

Parameters:
  • mu1 (torch.Tensor) – Mean of the first Gaussian (or a batch of first Gaussians)

  • mu2 (torch.Tensor) – Mean of the second Gaussian (or a batch of second Gaussians)

  • logvar1 (torch.Tensor) – Log-variance of the first Gaussian (or a batch of first Gaussians)

  • logvar2 (torch.Tensor) – Log-variance of the second Gaussian (or a batch of second Gaussians)
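
For reference, the closed-form expression being computed (with each distribution parameterized by a mean and a log-variance) can be written as a short PyTorch sketch. The helper name below is hypothetical and the sum over the last dimension is an assumption; the actual method additionally handles the reduction argument and the optional tc_penalty_weight term.

    import torch

    def diagonal_gaussian_kl(mu1, mu2, logvar1, logvar2):
        # KL(N(mu1, diag(exp(logvar1))) || N(mu2, diag(exp(logvar2)))),
        # evaluated element-wise and summed over the last (latent) dimension.
        var1 = logvar1.exp()
        var2 = logvar2.exp()
        kl = 0.5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)
        return kl.sum(dim=-1)  # one divergence per batch element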

property param_size#
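
A minimal construction sketch for this class, assuming the keyword meanings suggested by the signature above (latent dimensionality, whether the prior mean and log-variance are trainable, and a clamp on the log-variance); the configuration values are illustrative only.

    from cyto_dl.models.vae.priors.gaussian import DiagonalGaussianPrior

    # Hypothetical configuration; argument semantics are inferred from the
    # constructor signature, not from the implementation.
    prior = DiagonalGaussianPrior(
        dimensionality=32,   # size of the latent space
        learn_mean=True,     # treat the prior mean as a trainable parameter
        learn_logvar=True,   # treat the prior log-variance as a trainable parameter
        clamp_logvar=8.0,    # assumed bound on the log-variance for numerical stability
    )
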
class cyto_dl.models.vae.priors.gaussian.IsotropicGaussianPrior(*, dimensionality, clamp_logvar=8, tc_penalty_weight=None)[source]#

Bases: Prior

forward(z, mode='kl', inference=False, **kwargs)[source]#
classmethod kl_divergence(mean, logvar, tc_penalty_weight=None, reduction='sum')[source]#

Computes the Kullback-Leibler divergence between a diagonal Gaussian and the standard isotropic Gaussian N(0, I). The inputs may also be batches of distribution parameters, in which case the divergence is computed batch-wise.

Parameters:
  • mean (torch.Tensor) – Mean of the Gaussian (or a batch of Gaussians)

  • logvar (torch.Tensor) – Log-variance of the Gaussian (or a batch of Gaussians)
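
This is the usual VAE regularization term. A minimal sketch of the closed-form expression, with a hypothetical helper name and an assumed sum over the latent dimension (the actual method also supports the reduction argument and the optional tc_penalty_weight term):

    import torch

    def standard_gaussian_kl(mean, logvar):
        # KL(N(mean, diag(exp(logvar))) || N(0, I)), summed over the latent
        # dimension so each batch element yields a single divergence value.
        return -0.5 * (1.0 + logvar - mean.pow(2) - logvar.exp()).sum(dim=-1)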

property param_size#
classmethod sample(mean, logvar)[source]#
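
A stand-alone sketch of the standard reparameterization trick for drawing z ~ N(mean, diag(exp(logvar))); whether this classmethod uses exactly this formulation is an assumption, and the helper name is hypothetical.

    import torch

    def reparameterized_sample(mean, logvar):
        # z = mean + sigma * eps with eps ~ N(0, I), so gradients can flow
        # through mean and logvar during training.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std
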
cyto_dl.models.vae.priors.gaussian.compute_tc_penalty(logvar)[source]#
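
This helper is undocumented here. Assuming it returns a total-correlation-style penalty computed from the posterior log-variance, the tc_penalty_weight arguments above would combine it with the KL term roughly as in the sketch below; shapes, reductions, and the weight value are assumptions.

    import torch

    from cyto_dl.models.vae.priors.gaussian import (
        IsotropicGaussianPrior,
        compute_tc_penalty,
    )

    mean = torch.zeros(8, 32)    # batch of posterior means
    logvar = torch.zeros(8, 32)  # batch of posterior log-variances
    tc_penalty_weight = 1.0      # hypothetical weight

    # kl_divergence can apply the penalty itself when tc_penalty_weight is given;
    # the explicit combination below only illustrates the relationship.
    kl = IsotropicGaussianPrior.kl_divergence(mean, logvar)
    loss = kl + tc_penalty_weight * compute_tc_penalty(logvar)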