cyto_dl.models.jepa.ijepa module
- class cyto_dl.models.jepa.ijepa.IJEPA(*args, **kwargs)
Bases: JEPABase
JEPA for self-supervised learning on 2D and 3D images.
- Parameters:
encoder (nn.Module) – The encoder module used for feature extraction.
predictor (nn.Module) – The predictor module used for generating predictions.
x_key (str) – The key used to access the input data.
momentum (float, optional) – The momentum value for the exponential moving average of the model weights (default is 0.998).
max_epochs (int, optional) – The maximum number of training epochs (default is 100).
**base_kwargs (dict) – Additional arguments passed to the BaseModel.
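The momentum parameter above controls an exponential moving average (EMA) of model weights. As a rough illustration of what such an update does, here is a minimal sketch in plain Python. The function name ema_update and the toy weight lists are hypothetical and are not part of cyto_dl; the actual IJEPA implementation may differ (for example, it operates on module parameters and might vary the momentum during training).

```python
def ema_update(target, online, momentum=0.998):
    """Return target weights moved one EMA step toward the online weights.

    Hypothetical helper for illustration only; not a cyto_dl function.
    """
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


# Toy "weights": the target starts away from the online values.
target = [0.0, 1.0]
online = [1.0, 1.0]

# Apply three EMA steps with the documented default momentum of 0.998.
for _ in range(3):
    target = ema_update(target, online)

print(target)
```

With a momentum close to 1, each step moves the averaged weights only slightly toward the current ones, so the average changes slowly and smoothly.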