cyto_dl.models.vae.base_vae module
- class cyto_dl.models.vae.base_vae.BaseVAE(*args, **kwargs)
Bases: BaseModel
Instantiate a basic VAE model.
- Parameters:
encoder (nn.Module) – Encoder network
decoder (nn.Module) – Decoder network
x_label (Optional[str] = None)
id_label (Optional[str] = None)
beta (float = 1.0) – Beta parameter: the weight of the KL divergence (KLD) term in the loss function
reconstruction_loss (Loss) – Loss used for reconstruction. Can be a PyTorch loss or any class with the same interface, i.e. a subclass of torch.nn.modules.loss._Loss
prior (Optional[Sequence[AbstractPrior]]) – List of prior specifications to use for the latent space
decoder_latent_parts (Optional[Dict[str, Sequence[str]]] = None) – Dictionary that specifies, for each output part's decoder, which latent keys it depends on
**base_kwargs – Additional arguments passed to BaseModel
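To illustrate how beta and reconstruction_loss interact, here is a minimal sketch of a beta-weighted VAE objective. It is not cyto_dl's implementation: it assumes a diagonal Gaussian posterior N(mu, exp(logvar)), a standard normal prior N(0, I) (the closed-form KLD only holds for that case), and a PyTorch loss constructed with reduction="none". The function name beta_vae_loss is hypothetical.

```python
import torch
from torch import nn


def beta_vae_loss(x, x_hat, mu, logvar, beta=1.0, reconstruction_loss=None):
    """Hypothetical sketch of a beta-weighted VAE loss (not cyto_dl's code).

    Assumes q(z|x) = N(mu, exp(logvar)) and prior p(z) = N(0, I).
    """
    if reconstruction_loss is None:
        # Assumption: a PyTorch loss with reduction="none" gives per-element values.
        reconstruction_loss = nn.MSELoss(reduction="none")
    # Sum the per-element reconstruction loss over all non-batch dimensions.
    rcl = reconstruction_loss(x_hat, x).flatten(1).sum(dim=1)
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)
    # beta scales the KLD term relative to the reconstruction term.
    return (rcl + beta * kld).mean(), rcl.mean(), kld.mean()


torch.manual_seed(0)
x = torch.rand(4, 1, 8, 8)
x_hat = x.clone()          # perfect reconstruction, so rcl == 0
mu = torch.ones(4, 2)      # each latent dim contributes 0.5 to the KLD
logvar = torch.zeros(4, 2)
loss, rcl, kld = beta_vae_loss(x, x_hat, mu, logvar, beta=2.0)
print(loss.item())  # 2.0
```

With beta = 1.0 (the default) the two terms are weighted equally; larger values push the posterior harder toward the prior at the cost of reconstruction quality.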
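The decoder_latent_parts mapping can be pictured as follows. This is a sketch under stated assumptions: the output part names ("image", "mask"), the latent keys ("shape", "intensity") and the helper gather_decoder_inputs are all hypothetical, and concatenating the selected latents along the feature axis is an assumed way of combining them, not something the source confirms.

```python
import torch


def gather_decoder_inputs(latents, decoder_latent_parts):
    """Hypothetical helper: build each decoder's input from its latent parts.

    latents: dict mapping latent key -> tensor of shape (batch, dim)
    decoder_latent_parts: dict mapping output part -> sequence of latent keys
    """
    inputs = {}
    for part, keys in decoder_latent_parts.items():
        # Assumption: the selected latent tensors are concatenated feature-wise.
        inputs[part] = torch.cat([latents[k] for k in keys], dim=1)
    return inputs


latents = {"shape": torch.zeros(3, 4), "intensity": torch.ones(3, 2)}
# The "image" decoder depends on both latent parts; "mask" only on "shape".
decoder_latent_parts = {"image": ["shape", "intensity"], "mask": ["shape"]}
inputs = gather_decoder_inputs(latents, decoder_latent_parts)
print({k: tuple(v.shape) for k, v in inputs.items()})  # {'image': (3, 6), 'mask': (3, 4)}
```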