cyto_dl.models.jepa.ijepa module

class cyto_dl.models.jepa.ijepa.IJEPA(*args, **kwargs)

Bases: JEPABase

I-JEPA (Image-based Joint-Embedding Predictive Architecture) model for self-supervised learning on 2D and 3D images.

Parameters:
  • encoder (nn.Module) – The encoder module used for feature extraction.

  • predictor (nn.Module) – The predictor module used for generating predictions.

  • x_key (str) – The key used to access the input data.

  • momentum (float, optional) – The momentum value for the exponential moving average of the model weights (default is 0.998).

  • max_epochs (int, optional) – The maximum number of training epochs (default is 100).

  • **base_kwargs (dict) – Additional arguments passed to the BaseModel.
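
The momentum parameter above can be illustrated with a minimal sketch of an exponential moving average update. It assumes the conventional rule target = momentum * target + (1 - momentum) * online; the helper name ema_update and the use of plain Python lists in place of tensors are illustrative assumptions, not part of cyto_dl, and the class's actual update code may differ.

```python
def ema_update(target, online, momentum=0.998):
    """Return the target weights after one EMA step toward the online weights.

    Hypothetical helper for illustration only; plain floats stand in for tensors.
    """
    if not 0.0 <= momentum <= 1.0:
        raise ValueError("momentum must be in [0, 1]")
    # Each target weight keeps a `momentum` share of its old value and takes
    # the remaining share from the corresponding online weight.
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


# Toy weights: with momentum close to 1, the target moves only slightly
# toward the online weights on each step.
target_weights = [0.0, 1.0]
online_weights = [1.0, 1.0]
target_weights = ema_update(target_weights, online_weights)
print(target_weights)
```

With the documented default of 0.998, each step moves the averaged weights 0.2% of the way toward the current ones, so the averaged copy changes much more slowly than the weights being trained.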

model_step(stage, batch, batch_idx)