cyto_dl.models.jepa.iwm module
- class cyto_dl.models.jepa.iwm.IWM(*args, **kwargs)
Bases: JEPABase
Image World Model (IWM) for self-supervised learning of an encoder and a predictor that performs image-to-image translation in latent space.
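The idea of latent-space translation can be illustrated with a minimal, framework-free sketch of one training step. This is not cyto_dl's implementation. The function names, the default dictionary keys ("source", "target", "target_domain"), the mean squared error loss, and the assumption that the predictor is conditioned on a target-domain value are all hypothetical. Masking and batching are also omitted. Embeddings are plain Python lists so the data flow stays visible, and `target_encoder` stands in for the momentum-averaged copy of the encoder.

```python
from typing import Any, Callable, Dict, List

Vector = List[float]


def mse(a: Vector, b: Vector) -> float:
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def iwm_training_loss(
    encoder: Callable[[Any], Vector],
    target_encoder: Callable[[Any], Vector],
    predictor: Callable[[Vector, Any], Vector],
    batch: Dict[str, Any],
    source_key: str = "source",  # hypothetical default key
    target_key: str = "target",  # hypothetical default key
    target_domain_key: str = "target_domain",  # hypothetical default key
) -> float:
    # Embed the source image with the trainable (context) encoder.
    source_emb = encoder(batch[source_key])
    # Embed the target image with the slowly updated teacher encoder;
    # a real implementation would do this without tracking gradients.
    target_emb = target_encoder(batch[target_key])
    # Translate the source embedding toward the requested target domain.
    predicted = predictor(source_emb, batch[target_domain_key])
    # The loss pushes the translated embedding toward the teacher's target embedding.
    return mse(predicted, target_emb)


if __name__ == "__main__":
    double = lambda x: [2.0 * v for v in x]  # toy encoder
    shift = lambda emb, domain: [v + domain for v in emb]  # toy predictor
    batch = {"source": [1.0, 2.0], "target": [1.0, 2.0], "target_domain": 1.0}
    print(iwm_training_loss(double, double, shift, batch))  # 1.0
```

In the usage example, the toy encoder doubles each value and the toy predictor adds the domain value. A target-domain value of 1.0 therefore shifts every prediction by 1.0 and gives a loss of 1.0.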
- Parameters:
encoder (nn.Module) – The encoder module used for feature extraction.
predictor (nn.Module) – The predictor module used for generating predictions.
source_key (str) – The key used to access the source data.
target_key (str) – The key used to access the target data.
target_domain_key (str) – The key used to access the target domain data.
save_dir (str, optional) – The directory to save the model predictions (default is “./”).
momentum (float, optional) – The momentum value for the exponential moving average of the model weights (default is 0.998).
max_epochs (int, optional) – The maximum number of training epochs (default is 100).
predict_target (bool, optional) – Whether to predict the target embeddings instead of only extracting embeddings of the source image (default is False).
**base_kwargs (dict) – Additional keyword arguments passed to the BaseModel.
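Two of the parameters above can be made concrete with a short sketch. `momentum` is the coefficient of an exponential moving average, so each update computes teacher = m * teacher + (1 - m) * student for every weight. `predict_target` switches between returning the source embedding and returning the predictor's translation of it. The sketch is hypothetical and is not cyto_dl's API. The function names, the dictionary-of-lists weight format, and the choice of encoder used at prediction time are all assumptions.

```python
from typing import Any, Callable, Dict, List

Weights = Dict[str, List[float]]  # hypothetical format: parameter name -> flat values
Vector = List[float]


def ema_update(teacher: Weights, student: Weights, momentum: float = 0.998) -> Weights:
    """Move each teacher weight a fraction (1 - momentum) of the way toward the student."""
    return {
        name: [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher[name], student[name])]
        for name in teacher
    }


def predict_embeddings(
    encoder: Callable[[Any], Vector],
    predictor: Callable[[Vector, Any], Vector],
    batch: Dict[str, Any],
    source_key: str,
    target_domain_key: str,
    predict_target: bool = False,
) -> Vector:
    # Assumption: prediction starts from the encoder's embedding of the source image.
    source_emb = encoder(batch[source_key])
    if not predict_target:
        # Default behavior: return the source image embedding as-is.
        return source_emb
    # predict_target=True: translate the source embedding into the target domain.
    return predictor(source_emb, batch[target_domain_key])
```

With the default momentum of 0.998, each update moves the teacher weights 0.2% of the way toward the student weights. A momentum close to 1 makes the target embeddings change slowly.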