cyto_dl.models.jepa.jepa_base module
- class cyto_dl.models.jepa.jepa_base.JEPABase(*args, **kwargs)
Bases: BaseModel
Base class shared by the IJEPA and IWM models.
- Parameters:
encoder (nn.Module) – The encoder module used for feature extraction.
predictor (nn.Module) – The predictor module used for generating predictions.
x_key (str) – The key used to look up the input data in each batch.
momentum (float, optional) – The momentum value for the exponential moving average of the model weights (default is 0.998).
max_epochs (int, optional) – The maximum number of training epochs (default is 100).
**base_kwargs (dict) – Additional keyword arguments passed to the BaseModel constructor.
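The `momentum` and `max_epochs` parameters can be illustrated with a minimal, framework-agnostic sketch. It assumes that the exponential moving average updates a frozen target copy of the weights from the trained weights. It also assumes that the momentum is ramped linearly toward 1.0 over `max_epochs`, as in the I-JEPA paper's training recipe. Whether `JEPABase` schedules momentum this way is an assumption. `momentum_at` and `ema_update` are hypothetical helpers, not part of the cyto_dl API, and weights are modeled as plain dicts of float lists rather than tensors.

```python
"""Framework-agnostic sketch of an EMA ("momentum") weight update.

Assumptions (not taken from cyto_dl): weights are plain dicts of float lists,
and momentum ramps linearly from its base value to 1.0 over max_epochs, as in
the I-JEPA recipe. The real JEPABase may schedule momentum differently.
"""
from typing import Dict, List

Weights = Dict[str, List[float]]


def momentum_at(epoch: int, base_momentum: float = 0.998, max_epochs: int = 100) -> float:
    """Hypothetical linear ramp: base_momentum at epoch 0, 1.0 at max_epochs and beyond."""
    progress = min(max(epoch / max_epochs, 0.0), 1.0)
    return base_momentum + (1.0 - base_momentum) * progress


def ema_update(target: Weights, online: Weights, momentum: float) -> Weights:
    """Return a new dict with target <- momentum * target + (1 - momentum) * online."""
    return {
        name: [momentum * t + (1.0 - momentum) * o for t, o in zip(target[name], online[name])]
        for name in target
    }


if __name__ == "__main__":
    online = {"w": [1.0, 2.0]}  # weights being trained by gradient descent
    target = {"w": [0.0, 0.0]}  # slowly moving EMA copy
    for epoch in range(3):
        m = momentum_at(epoch)
        target = ema_update(target, online, m)
        print(epoch, m, target)
```

A momentum close to 1.0 makes the target weights change very slowly. Ramping it toward 1.0 progressively freezes the target as training converges.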