import torch.nn as nn
from cyto_dl.models.jepa import JEPABase
class IJEPA(JEPABase):
    """JEPA for self-supervised learning on 2D and 3D images."""

    def __init__(
        self,
        *,
        encoder: nn.Module,
        predictor: nn.Module,
        x_key: str,
        save_dir: str = "./",
        momentum: float = 0.998,
        max_epochs: int = 100,
        **base_kwargs,
    ):
        """Configure the model by forwarding every argument to ``JEPABase``.

        Parameters
        ----------
        encoder : nn.Module
            The encoder module used for feature extraction.
        predictor : nn.Module
            The predictor module used for generating predictions.
        x_key : str
            The key used to access the input data.
        save_dir : str, optional
            Forwarded unchanged to the base class (default is "./").
            NOTE(review): presumably the directory where outputs are
            written -- confirm against ``JEPABase``.
        momentum : float, optional
            The momentum value for the exponential moving average of the model weights (default is 0.998).
        max_epochs : int, optional
            The maximum number of training epochs (default is 100).
        **base_kwargs : dict
            Additional arguments passed to the BaseModel.
        """
        super().__init__(
            encoder=encoder,
            predictor=predictor,
            x_key=x_key,
            save_dir=save_dir,
            momentum=momentum,
            max_epochs=max_epochs,
            **base_kwargs,
        )

    def model_step(self, stage, batch, batch_idx):
        """Run one step and return its loss.

        Parameters
        ----------
        stage
            Not used by this method.
        batch
            Indexable container holding the input under
            ``self.hparams.x_key``; it is also handed to ``self.get_mask``
            to look up the "target_mask" and "context_mask" entries.
        batch_idx
            Not used by this method.

        Returns
        -------
        tuple
            ``(loss, None, None)`` where ``loss`` is
            ``self.loss(predictions, target_embeddings)``.
            NOTE(review): the two ``None`` entries presumably fill slots the
            base class expects from ``model_step`` -- confirm against
            ``JEPABase``.
        """
        # The teacher update runs first, before any embeddings are computed.
        self.update_teacher()

        # Named ``x`` rather than ``input`` so the builtin is not shadowed.
        x = batch[self.hparams.x_key]
        x = self.remove_first_dim(x)

        target_masks = self.get_mask(batch, "target_mask")
        context_masks = self.get_mask(batch, "context_mask")

        target_embeddings = self.get_target_embeddings(x, target_masks)
        context_embeddings = self.get_context_embeddings(x, context_masks)

        # The predictor receives the context embeddings together with the
        # target masks; its output is compared against the target embeddings.
        predictions = self.predictor(context_embeddings, target_masks)
        loss = self.loss(predictions, target_embeddings)
        return loss, None, None