cyto_dl.models.jepa.jepa_base module

class cyto_dl.models.jepa.jepa_base.JEPABase(*args, **kwargs)

Bases: BaseModel

Base class for the IJEPA and IWM models.

Parameters:
  • encoder (nn.Module) – The encoder module used for feature extraction.

  • predictor (nn.Module) – The predictor module used for generating predictions.

  • x_key (str) – The key used to access the input data.

  • momentum (float, optional) – The momentum value for the exponential moving average of the model weights (default is 0.998).

  • max_epochs (int, optional) – The maximum number of training epochs (default is 100).

  • **base_kwargs (dict) – Additional arguments passed to the BaseModel.
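
The momentum parameter above controls an exponential moving average (EMA) of the model weights. As a rough illustration only, the sketch below applies the conventional EMA rule, new = momentum * old + (1 - momentum) * current, to plain Python floats standing in for weight tensors. The names ema_update, teacher, and student are hypothetical, and whether update_teacher() uses exactly this form is an assumption not confirmed by this page.

```python
def ema_update(teacher, student, momentum=0.998):
    """Blend each teacher weight toward the matching student weight.

    Assumed rule (conventional EMA, not confirmed for cyto_dl):
        new_teacher = momentum * teacher + (1 - momentum) * student
    """
    return [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher, student)]


# Plain floats stand in for weight tensors in this sketch.
teacher = [0.0, 1.0]
student = [1.0, 1.0]

# One update with the documented default momentum of 0.998.
teacher = ema_update(teacher, student)
print(teacher)
```

With a momentum close to 1, each step moves the averaged weights only slightly toward the current weights, so the averaged copy changes slowly over training.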

configure_optimizers()
forward(x)
get_context_embeddings(x, mask)
get_mask(batch, key)
get_target_embeddings(x, mask)
model_step(stage, batch, batch_idx)
predict_step(batch, batch_idx)
remove_first_dim(tensor)
update_teacher()