cyto_dl.nn.vits.predictor module

class cyto_dl.nn.vits.predictor.IWMPredictor(domains: List[str], num_patches: int | List[int], spatial_dims: int = 3, input_dim: int | None = 192, emb_dim: int | None = 192, num_layer: int | None = 12, num_head: int | None = 3)

Bases: JEPAPredictor

Specialized JEPA predictor that conditionally predicts between domains (e.g., from brightfield to multiple fluorescent tags).

  • domains (List[str]) – List of names of target domains

  • num_patches (List[int], int) – Number of patches in each dimension. If int, the same number of patches is used for all spatial dimensions

  • spatial_dims (int) – Number of spatial dimensions

  • input_dim (int) – Dimension of input

  • emb_dim (int) – Dimension of embedding

  • num_layer (int) – Number of transformer layers

  • num_head (int) – Number of heads in transformer
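The `num_patches` rule (an int is reused for every spatial dimension) can be sketched as follows. `normalize_num_patches` and `total_tokens` are hypothetical helpers written for illustration, not part of cyto_dl, and the token count assumes the patch grid is flattened into one token per patch.

```python
import math
from typing import List, Tuple, Union


def normalize_num_patches(num_patches: Union[int, List[int]], spatial_dims: int = 3) -> Tuple[int, ...]:
    """Expand an int into one patch count per spatial dimension (hypothetical helper)."""
    if isinstance(num_patches, int):
        return (num_patches,) * spatial_dims
    if len(num_patches) != spatial_dims:
        raise ValueError(f"expected {spatial_dims} patch counts, got {len(num_patches)}")
    return tuple(num_patches)


def total_tokens(num_patches: Union[int, List[int]], spatial_dims: int = 3) -> int:
    """Number of patch tokens when the grid is flattened to one token per patch."""
    return math.prod(normalize_num_patches(num_patches, spatial_dims))


print(normalize_num_patches(4))           # (4, 4, 4)
print(normalize_num_patches([2, 8, 8]))   # (2, 8, 8)
print(total_tokens(16, spatial_dims=2))   # 256
```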

forward(context_emb, target_masks, target_domain)
class cyto_dl.nn.vits.predictor.JEPAPredictor(num_patches: int | List[int], spatial_dims: int = 3, input_dim: int | None = 192, emb_dim: int | None = 192, num_layer: int | None = 12, num_head: int | None = 3, learnable_pos_embedding: bool | None = True)
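The signature does not state how `target_domain` conditions the prediction. The plain-Python sketch below assumes that `target_domain` is one of the names passed in `domains` and that each domain owns an embedding vector added to every mask token. `DomainTokenTable` and its `condition` method are hypothetical names, and IWMPredictor may condition on the domain differently.

```python
from typing import Dict, List


class DomainTokenTable:
    """Hypothetical per-domain embedding lookup used to condition mask tokens."""

    def __init__(self, domains: List[str], emb_dim: int):
        self.index: Dict[str, int] = {name: i for i, name in enumerate(domains)}
        # Placeholder values (domain i -> vector of i's); a real model would learn these.
        self.table = [[float(i)] * emb_dim for i in range(len(domains))]

    def condition(self, mask_tokens: List[List[float]], target_domain: str) -> List[List[float]]:
        """Add the selected domain's vector to every mask token."""
        if target_domain not in self.index:
            raise KeyError(f"unknown domain {target_domain!r}; expected one of {list(self.index)}")
        domain_vec = self.table[self.index[target_domain]]
        return [[m + d for m, d in zip(token, domain_vec)] for token in mask_tokens]


# "nucleus" and "membrane" are made-up domain names for illustration.
table = DomainTokenTable(["nucleus", "membrane"], emb_dim=4)
print(table.condition([[0.0] * 4, [0.5] * 4], "membrane"))
# [[1.0, 1.0, 1.0, 1.0], [1.5, 1.5, 1.5, 1.5]]
```

Because the same mask tokens are reused for every domain, only the added domain vector tells the predictor which target to produce.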

Bases: Module

Class for predicting target features from context embedding.

  • num_patches (List[int], int) – Number of patches in each dimension. If int, the same number of patches is used for all spatial dimensions

  • spatial_dims (int) – Number of spatial dimensions

  • input_dim (int) – Dimension of input

  • emb_dim (int) – Dimension of embedding

  • num_layer (int) – Number of transformer layers

  • num_head (int) – Number of heads in transformer

  • learnable_pos_embedding (bool) – If True, learnable positional embeddings are used; if False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
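With `learnable_pos_embedding=False`, a fixed sin/cos table can be built without any trainable parameters. The sketch below assumes an MAE-style layout, in which the embedding channels are split evenly across spatial axes and each axis gets a 1-D sin/cos code of its patch coordinate (frequency base 10000). cyto_dl's own layout may differ, and `sincos_1d` and `sincos_grid` are hypothetical names.

```python
import itertools
import math
from typing import List, Sequence


def sincos_1d(pos: int, dim: int) -> List[float]:
    """Sin/cos code for one integer coordinate: dim/2 sines followed by dim/2 cosines."""
    half = dim // 2
    freqs = [1.0 / (10000 ** (i / half)) for i in range(half)]
    return [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]


def sincos_grid(num_patches: Sequence[int], emb_dim: int) -> List[List[float]]:
    """Fixed embedding for every patch of an N-D grid, in row-major order."""
    n_axes = len(num_patches)
    if emb_dim % (2 * n_axes) != 0:
        raise ValueError("emb_dim must be divisible by 2 * number of spatial dims")
    per_axis = emb_dim // n_axes
    table = []
    # itertools.product varies the last axis fastest, matching a C-order flatten.
    for coord in itertools.product(*(range(n) for n in num_patches)):
        vec: List[float] = []
        for c in coord:
            vec.extend(sincos_1d(c, per_axis))  # one channel block per axis
        table.append(vec)
    return table


pe = sincos_grid([2, 4, 4], emb_dim=192)
print(len(pe), len(pe[0]))  # 32 192
```

With the defaults `spatial_dims=3` and `emb_dim=192`, each axis receives 64 channels, which satisfies the divisibility check.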

forward(context_emb, target_masks)
predict_target_features(context_emb, target_masks)
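The docstring does not spell out how `predict_target_features` builds its input. A common JEPA-style pattern, assumed here rather than confirmed for cyto_dl, appends one query token per target patch (a shared mask token plus that patch's positional embedding) to the context tokens, runs the transformer over the combined sequence, and keeps only the outputs at the query slots. The sketch treats `target_masks` as a list of flat patch indices, assumes the context tokens are already projected from `input_dim` to `emb_dim`, and uses a toy averaging function in place of the transformer. `predict_targets` and `toy_mixer` are hypothetical names.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def predict_targets(
    context_tokens: List[Vector],
    pos_embed: List[Vector],
    target_idx: Sequence[int],
    mask_token: Vector,
    transformer: Callable[[List[Vector]], List[Vector]],
) -> List[Vector]:
    """Append one query per target patch, run the model, return outputs at query slots."""
    # Query = shared mask token + positional embedding of the patch to predict.
    queries = [[m + p for m, p in zip(mask_token, pos_embed[i])] for i in target_idx]
    out = transformer(context_tokens + queries)
    # The first len(context_tokens) outputs belong to context; the rest are predictions.
    return out[len(context_tokens):]


def toy_mixer(seq: List[Vector]) -> List[Vector]:
    """Stand-in for a transformer: every token becomes the mean of the sequence."""
    mean = [sum(col) / len(seq) for col in zip(*seq)]
    return [list(mean) for _ in seq]


context = [[1.0, 1.0]]                          # one context token, emb_dim = 2
pos = [[0.0, 0.0], [0.5, 0.5], [2.0, 2.0]]      # embeddings for 3 patch positions
print(predict_targets(context, pos, [2], [0.5, 0.5], toy_mixer))  # [[1.75, 1.75]]
```

The query carries only position information, so any content in its output must come from mixing with the context tokens, which is what the toy mixer imitates.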