# Source code for cyto_dl.models.vae.priors.joint_prior
import torch
import torch.nn as nn
from .abstract_prior import Prior
class JointPrior(Prior):
    """Prior built from several sub-priors that each own a block of columns.

    The incoming parameter tensor is indexed as ``z_params[:, a:b]``, i.e. the
    parameters of the sub-priors are expected to be laid out side by side along
    dim 1, in the same order as ``priors``. Each sub-prior is invoked through
    its module call (``prior(block, mode=...)``) and the per-prior results are
    concatenated back along dim 1.
    """

    def __init__(self, priors):
        """Store the sub-priors and register the summed dimensionality.

        Parameters
        ----------
        priors
            Iterable of sub-priors (or an existing ``nn.ModuleList``). Each
            element must expose ``dimensionality`` and ``param_size``.
        """
        # nn.ModuleList expects ONE iterable of modules; unpacking it with `*`
        # raises TypeError. Wrapping first also means a one-shot iterable
        # (e.g. a generator) is materialized before it is summed over below.
        if not isinstance(priors, nn.ModuleList):
            priors = nn.ModuleList(priors)

        super().__init__(sum(prior.dimensionality for prior in priors))
        self.priors = priors

    @property
    def param_size(self):
        """Total number of parameter columns consumed by all sub-priors."""
        return sum(prior.param_size for prior in self.priors)

    def _param_blocks(self, z_params):
        """Yield ``(prior, block)`` pairs, one contiguous column block per prior.

        Python slices exclude the end index, so a block is exactly
        ``[start, start + width)`` and the next block starts at that same end.

        NOTE(review): blocks are ``prior.param_size`` columns wide, matching
        how this class reports its own ``param_size``. This assumes callers
        pass a tensor with ``self.param_size`` columns -- confirm against the
        encoder that produces ``z_params``.
        """
        start_ix = 0
        for prior in self.priors:
            end_ix = start_ix + prior.param_size
            yield prior, z_params[:, start_ix:end_ix]
            start_ix = end_ix

    def kl_divergence(self, z_params, reduction="sum"):
        """Compute the KL term of every sub-prior and combine them.

        Parameters
        ----------
        z_params
            Tensor holding the concatenated sub-prior parameters along dim 1.
        reduction
            ``"none"`` returns the concatenated per-column values; ``"sum"``
            sums them over the last dim.

        Raises
        ------
        NotImplementedError
            For any other ``reduction`` value.
        """
        if reduction not in ("none", "sum"):
            raise NotImplementedError(f"Reduction '{reduction}' not implemented for JointPrior")

        # Always ask the sub-priors for unreduced values: the results are
        # concatenated along dim 1 and reduced once here, which would not be
        # possible with tensors that were already reduced per prior.
        kl = torch.cat(
            [
                prior(block, mode="kl", reduction="none")
                for prior, block in self._param_blocks(z_params)
            ],
            dim=1,
        )

        if reduction == "none":
            return kl
        return kl.sum(dim=-1)

    def sample(self, z_params, inference=False):
        """Sample from every sub-prior and concatenate the samples along dim 1.

        Parameters
        ----------
        z_params
            Tensor holding the concatenated sub-prior parameters along dim 1.
        inference
            Forwarded unchanged to each sub-prior.
        """
        return torch.cat(
            [
                prior(block, mode="sample", inference=inference)
                for prior, block in self._param_blocks(z_params)
            ],
            dim=1,
        )

    def forward(self, z_params, mode="kl", inference=False, **kwargs):
        """Dispatch to :meth:`kl_divergence` (``mode="kl"``) or :meth:`sample`.

        A ``reduction`` keyword, if given, is forwarded to ``kl_divergence``
        so this class honours the same call convention it uses on its own
        sub-priors. Any other extra keyword is ignored. Every ``mode`` other
        than ``"kl"`` samples.
        """
        if mode == "kl":
            return self.kl_divergence(z_params, reduction=kwargs.get("reduction", "sum"))
        return self.sample(z_params, inference=inference)