cyto_dl.models.vae.base_vae module

class cyto_dl.models.vae.base_vae.BaseVAE(*args, **kwargs)

Bases: BaseModel

Instantiate a basic VAE model.

Parameters:

  • encoder (nn.Module) – Encoder network

  • decoder (nn.Module) – Decoder network

  • x_label (Optional[str] = None)

  • id_label (Optional[str] = None)

  • beta (float = 1.0) – Weight of the KL divergence (KLD) term in the loss function

  • reconstruction_loss (Loss) – Loss to be used for reconstruction. Can be a PyTorch loss or any class that follows the same interface, i.e. one that subclasses torch.nn.modules.loss._Loss

  • prior (Optional[Sequence[AbstractPrior]]) – List of prior specifications to use for the latent space

  • decoder_latent_parts (Optional[Dict[str, Sequence[str]]] = None) – Dictionary that specifies, for each output part’s decoder, which latent keys it depends on

  • **base_kwargs – Additional arguments passed to BaseModel
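
The shape of decoder_latent_parts is easiest to see on a small example. The sketch below is plain Python, not cyto_dl code: the helper gather_latents, the latent keys "shape" and "pose", and the output parts "image" and "pose_only" are all hypothetical, and joining the listed latents by concatenation is an assumption about how such a mapping could be consumed, not a statement of what BaseVAE does internally.

```python
from typing import Dict, List, Sequence


def gather_latents(
    z: Dict[str, List[float]],
    decoder_latent_parts: Dict[str, Sequence[str]],
) -> Dict[str, List[float]]:
    """For each output part, concatenate the latent vectors it depends on.

    Hypothetical helper for illustration; concatenation is an assumption.
    """
    inputs: Dict[str, List[float]] = {}
    for part, latent_keys in decoder_latent_parts.items():
        merged: List[float] = []
        for key in latent_keys:
            merged.extend(z[key])  # raises KeyError if a listed key is missing
        inputs[part] = merged
    return inputs


# Hypothetical latent parts and output parts, chosen only for this example.
z = {"shape": [0.1, -0.3], "pose": [1.5]}
decoder_latent_parts = {"image": ["shape", "pose"], "pose_only": ["pose"]}

decoder_inputs = gather_latents(z, decoder_latent_parts)
print(decoder_inputs)
```

Here the "image" decoder would see both latent parts, while the "pose_only" decoder would see only the "pose" latent.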

calculate_elbo(x, xhat, z)
calculate_rcl(x, xhat, input_key, target_key=None)
calculate_rcl_dict(x, xhat, z)
decode(z)
encode(batch, **kwargs)
forward(batch, decode=False, inference=True, return_params=False, **kwargs)
model_step(stage, batch, batch_idx)
sample_z(z_parts_params, inference=False)
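
To make the role of beta concrete, here is a minimal sketch of a beta-weighted VAE objective, written in plain Python so it runs without PyTorch. It is not the body of calculate_elbo: the function names mse, gaussian_kld, and beta_vae_loss are hypothetical, the reconstruction term is assumed to be mean squared error, and the prior is assumed to be a standard normal with a diagonal Gaussian posterior. In BaseVAE the reconstruction loss and the priors are configurable, so the actual computation may differ.

```python
import math
from typing import Sequence


def mse(x: Sequence[float], xhat: Sequence[float]) -> float:
    """Mean squared reconstruction error (assumed reconstruction loss)."""
    return sum((a - b) ** 2 for a, b in zip(x, xhat)) / len(x)


def gaussian_kld(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL( N(mu, exp(logvar)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(
        m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar)
    )


def beta_vae_loss(x, xhat, mu, logvar, beta: float = 1.0) -> float:
    """Reconstruction term plus beta times the KLD term (a loss to minimize)."""
    return mse(x, xhat) + beta * gaussian_kld(mu, logvar)


# Toy values chosen so each term is easy to check by hand.
x, xhat = [1.0, 2.0], [1.0, 4.0]    # reconstruction error: (0 + 4) / 2 = 2.0
mu, logvar = [1.0, 0.0], [0.0, 0.0]  # KLD: 0.5 * (1 + 0) = 0.5

loss_beta_1 = beta_vae_loss(x, xhat, mu, logvar, beta=1.0)
loss_beta_4 = beta_vae_loss(x, xhat, mu, logvar, beta=4.0)
print(loss_beta_1, loss_beta_4)
```

With beta = 1.0 the two terms are weighted equally; raising beta penalizes divergence from the prior more heavily relative to reconstruction quality.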