# Source code for cyto_dl.models.jepa.ijepa

import torch.nn as nn

from cyto_dl.models.jepa import JEPABase


class IJEPA(JEPABase):
    """JEPA for self-supervised learning on 2D and 3D images."""

    def __init__(
        self,
        *,
        encoder: nn.Module,
        predictor: nn.Module,
        x_key: str,
        save_dir: str = "./",
        momentum: float = 0.998,
        max_epochs: int = 100,
        **base_kwargs,
    ):
        """Initialize the model by forwarding every argument to ``JEPABase``.

        Parameters
        ----------
        encoder : nn.Module
            The encoder module used for feature extraction.
        predictor : nn.Module
            The predictor module used for generating predictions.
        x_key : str
            The key used to access the input data in a batch.
        save_dir : str, optional
            Forwarded unchanged to the parent constructor (default is "./").
            Presumably the directory where outputs are written -- confirm
            against ``JEPABase``.
        momentum : float, optional
            The momentum value for the exponential moving average of the model
            weights (default is 0.998).
        max_epochs : int, optional
            The maximum number of training epochs (default is 100).
        **base_kwargs : dict
            Additional keyword arguments forwarded to the parent constructor.
        """
        super().__init__(
            encoder=encoder,
            predictor=predictor,
            x_key=x_key,
            save_dir=save_dir,
            momentum=momentum,
            max_epochs=max_epochs,
            **base_kwargs,
        )

    def model_step(self, stage, batch, batch_idx):
        """Run one step and return its loss.

        Parameters
        ----------
        stage
            Not used by this method.
        batch
            Mapping-like batch. It is indexed with ``self.hparams.x_key`` and
            also handed to ``self.get_mask`` together with the names
            ``"target_mask"`` and ``"context_mask"``.
        batch_idx
            Not used by this method.

        Returns
        -------
        tuple
            ``(loss, None, None)``, where ``loss`` is the result of
            ``self.loss(predictions, target_embeddings)``.
        """
        # NOTE(review): the teacher update runs on every call, whatever the
        # stage; confirm update_teacher() itself skips non-training stages.
        self.update_teacher()

        # `images` rather than `input`, so the builtin is not shadowed.
        images = batch[self.hparams.x_key]
        images = self.remove_first_dim(images)

        target_masks = self.get_mask(batch, "target_mask")
        context_masks = self.get_mask(batch, "context_mask")

        target_embeddings = self.get_target_embeddings(images, target_masks)
        context_embeddings = self.get_context_embeddings(images, context_masks)

        # The predictor receives the context embeddings together with the
        # target masks; its output is compared to the target embeddings.
        predictions = self.predictor(context_embeddings, target_masks)

        loss = self.loss(predictions, target_embeddings)
        return loss, None, None