Source code for cyto_dl.models.vae.priors.identity_prior

import torch

from .abstract_prior import Prior


class IdentityPrior(Prior):
    """Prior class that doesn't contribute to KL loss.

    Effectively a Dirac delta distribution given z: in ``"kl"`` mode the KL
    term is identically zero, and in any other mode ``z`` is returned
    unchanged.
    """

    def forward(self, z, mode="kl", **kwargs):
        """Return a zero KL term, or pass ``z`` through.

        Parameters
        ----------
        z : torch.Tensor
            Latent tensor. Its shape is not inspected here.
        mode : str, default "kl"
            If ``"kl"``, return a zero KL loss. Any other value returns ``z``
            unchanged.
        **kwargs
            Accepted but ignored here. Presumably this keeps the call
            signature compatible with other ``Prior`` subclasses; verify
            against callers.

        Returns
        -------
        torch.Tensor
            A 0-dim zero tensor with the same dtype and device as ``z`` when
            ``mode == "kl"``, otherwise ``z`` itself.
        """
        if mode == "kl":
            # Allocate the scalar zero directly with z's dtype/device rather
            # than creating a CPU tensor and converting it via ``type_as`` on
            # every call, which costs a host-to-device copy for GPU inputs.
            return z.new_zeros(())
        return z

    @property
    def param_size(self):
        """Size of the parameter vector for this prior.

        Equal to ``self.dimensionality``. That attribute is presumably set
        by the ``Prior`` base class (not visible here); confirm there.
        """
        return self.dimensionality