# Source code for cyto_dl.models.vae.priors.identity_prior
import torch
from .abstract_prior import Prior
class IdentityPrior(Prior):
    """Prior class that doesn't contribute to the KL loss.

    Effectively a Dirac delta distribution given ``z``: the KL term is
    identically zero, and any non-KL evaluation returns ``z`` unchanged.
    """

    def forward(self, z, mode="kl", **kwargs):
        """Evaluate the prior for latent ``z``.

        Args:
            z: Latent tensor.
            mode: ``"kl"`` to obtain the KL contribution; any other value
                returns ``z`` unchanged.
            **kwargs: Accepted for interface compatibility with other priors;
                ignored here.

        Returns:
            When ``mode == "kl"``, a zero-dimensional zero tensor with the same
            dtype and device as ``z``. Otherwise, ``z`` itself (same object).
        """
        if mode == "kl":
            # Allocate the zero directly with z's dtype and device. This avoids
            # building a CPU tensor and copying it over on every call. It also
            # avoids type_as's placement on the *current* CUDA device, which
            # need not match z's device index.
            return z.new_zeros(())
        return z

    @property
    def param_size(self):
        """Parameter size for this prior, equal to ``self.dimensionality``.

        NOTE(review): ``self.dimensionality`` is presumably set by the ``Prior``
        base class; it is not defined in this file.
        """
        return self.dimensionality