cyto_dl.models.jepa.iwm module

class cyto_dl.models.jepa.iwm.IWM(*args, **kwargs)

Bases: JEPABase

Image World Model (IWM) for self-supervised learning of an encoder and a predictor that performs image translation in latent space.
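For orientation, a generic JEPA-style latent-translation objective can be written as follows. This is an assumption based on the JEPA family that IWM inherits from, not a statement of cyto_dl's exact loss. Here f_θ is the trainable encoder, f_ξ an exponential-moving-average (EMA) copy of it used to embed targets, g_φ the predictor, x_s and x_t the source and target images, d_t the target-domain information, sg(·) the stop-gradient operator, and m the momentum.

```latex
\mathcal{L}(\theta, \phi) = \big\lVert g_\phi\big(f_\theta(x_s),\, d_t\big) - \operatorname{sg}\big(f_\xi(x_t)\big) \big\rVert,
\qquad
\xi \leftarrow m\,\xi + (1 - m)\,\theta
```

The choice of distance (for example L1 or L2) and the way d_t enters the predictor are not specified on this page.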

Parameters:
  • encoder (nn.Module) – The encoder module used for feature extraction.

  • predictor (nn.Module) – The predictor module used for generating predictions.

  • source_key (str) – The key used to access the source data.

  • target_key (str) – The key used to access the target data.

  • target_domain_key (str) – The key used to access the target domain data.

  • save_dir (str, optional) – The directory in which to save model predictions (default is “./”).

  • momentum (float, optional) – The momentum value for the exponential moving average of the model weights (default is 0.998).

  • max_epochs (int, optional) – The maximum number of training epochs (default is 100).

  • predict_target (bool, optional) – Whether to predict the target embeddings instead of only extracting embeddings of the source image (default is False).

  • **base_kwargs (dict) – Additional keyword arguments passed to the BaseModel.
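The following is a minimal, self-contained PyTorch sketch of one IWM-style training step that ties the parameters above together, assuming a JEPA-style setup in which an EMA copy of the encoder embeds the target image. ToyPredictor, ema_update, training_step, and the batch keys "source", "target", and "target_domain" are hypothetical stand-ins for the predictor, momentum, source_key, target_key, and target_domain_key parameters; they are not cyto_dl's API, and the L1 loss is an assumed choice.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyPredictor(nn.Module):
    """Hypothetical predictor: maps a source embedding plus a target-domain id to a target embedding."""

    def __init__(self, dim: int, n_domains: int):
        super().__init__()
        self.domain_embed = nn.Embedding(n_domains, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, z_src: torch.Tensor, domain: torch.Tensor) -> torch.Tensor:
        # Condition on the target domain by adding a learned domain embedding (assumed scheme).
        return self.mlp(z_src + self.domain_embed(domain))


@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, momentum: float) -> None:
    # Hypothetical helper: target <- momentum * target + (1 - momentum) * online
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(momentum).add_(p_o.detach(), alpha=1.0 - momentum)


def training_step(batch, encoder, target_encoder, predictor, optimizer, momentum=0.998,
                  source_key="source", target_key="target", target_domain_key="target_domain"):
    # Embed the source image with the trainable encoder.
    z_src = encoder(batch[source_key])
    # Embed the target image with the EMA encoder; no gradients flow here.
    with torch.no_grad():
        z_tgt = target_encoder(batch[target_key])
    # Predict the target embedding from the source embedding and the target domain.
    z_pred = predictor(z_src, batch[target_domain_key])
    loss = F.l1_loss(z_pred, z_tgt)  # assumed regression loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Move the EMA encoder toward the freshly updated encoder.
    ema_update(target_encoder, encoder, momentum)
    return loss.item()


# Usage with toy 8x8 single-channel images and 3 hypothetical target domains.
torch.manual_seed(0)
dim, n_domains = 16, 3
encoder = nn.Sequential(nn.Flatten(), nn.Linear(8 * 8, dim))
target_encoder = copy.deepcopy(encoder).requires_grad_(False)
predictor = ToyPredictor(dim, n_domains)
optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(predictor.parameters()), lr=1e-3)
batch = {
    "source": torch.randn(4, 1, 8, 8),
    "target": torch.randn(4, 1, 8, 8),
    "target_domain": torch.randint(0, n_domains, (4,)),
}
loss = training_step(batch, encoder, target_encoder, predictor, optimizer)
```

Adding a learned domain embedding is only one simple way to condition the predictor; this page does not specify how IWM feeds the target-domain data to its predictor.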

extract_embeddings(tensor)
get_predict_masks(batch_size, device)
model_step(stage, batch, batch_idx)
predict_step(batch, batch_idx)
test_step(batch, batch_idx)