cyto_dl.nn.vits.predictor module
- class cyto_dl.nn.vits.predictor.IWMPredictor(domains: List[str], num_patches: int | List[int], spatial_dims: int = 3, input_dim: int | None = 192, emb_dim: int | None = 192, num_layer: int | None = 12, num_head: int | None = 3)
Bases: JEPAPredictor
Specialized JEPA predictor that can conditionally predict across different domains (e.g., from brightfield to multiple fluorescent tags).
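The docstring does not say how the target domain is communicated to the predictor. The sketch below shows one plausible scheme, assuming each name in `domains` owns an embedding vector that is added to the mask tokens at the target positions. `build_domain_embeddings` and `condition_mask_tokens` are hypothetical helpers, not part of cyto_dl, and the real IWMPredictor may inject the domain differently (for example as an extra token). Plain Python lists stand in for tensors so the sketch runs without PyTorch.

```python
import random
from typing import Dict, List


def build_domain_embeddings(domains: List[str], emb_dim: int, seed: int = 0) -> Dict[str, List[float]]:
    # Hypothetical stand-in for a learnable per-domain embedding table;
    # random values are used so the sketch runs without PyTorch.
    rng = random.Random(seed)
    return {name: [rng.gauss(0.0, 0.02) for _ in range(emb_dim)] for name in domains}


def condition_mask_tokens(mask_tokens: List[List[float]], domain_emb: List[float]) -> List[List[float]]:
    # Assumption: domain information is injected by adding the chosen
    # domain's vector to every mask (target) token, element-wise.
    return [[t + d for t, d in zip(token, domain_emb)] for token in mask_tokens]


# Usage: request predictions for a hypothetical "nucleus" domain
domains = ["membrane", "nucleus"]
emb_dim = 4
table = build_domain_embeddings(domains, emb_dim)
mask_tokens = [[0.0] * emb_dim for _ in range(3)]  # three target positions
conditioned = condition_mask_tokens(mask_tokens, table["nucleus"])
```

With zero-valued mask tokens, every conditioned token equals the `nucleus` vector, which makes the effect of the conditioning easy to see; passing `table["membrane"]` instead changes the requested domain without touching the rest of the input.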
- Parameters:
domains (List[str]) – List of names of target domains
num_patches (int | List[int]) – Number of patches in each dimension. If int, the same number of patches is used for all spatial dimensions
spatial_dims (int) – Number of spatial dimensions
input_dim (int) – Dimension of input
emb_dim (int) – Dimension of embedding
num_layer (int) – Number of transformer layers
num_head (int) – Number of heads in transformer
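Both predictors accept `num_patches` either as a single int or as one count per spatial dimension. A minimal sketch of that normalization, plus the resulting number of patch tokens, follows. `normalize_num_patches` and `total_tokens` are hypothetical names, and the length check on list input is an assumption rather than documented behavior.

```python
from math import prod
from typing import List, Union


def normalize_num_patches(num_patches: Union[int, List[int]], spatial_dims: int = 3) -> List[int]:
    """Return one patch count per spatial dimension."""
    if isinstance(num_patches, int):
        # An int means the same patch count along every spatial axis
        return [num_patches] * spatial_dims
    # Assumption: a list must supply exactly one count per spatial axis
    if len(num_patches) != spatial_dims:
        raise ValueError(f"expected {spatial_dims} patch counts, got {len(num_patches)}")
    return list(num_patches)


def total_tokens(num_patches: Union[int, List[int]], spatial_dims: int = 3) -> int:
    """Number of patch tokens in the full grid (product over axes)."""
    return prod(normalize_num_patches(num_patches, spatial_dims))
```

For example, `normalize_num_patches(4)` gives `[4, 4, 4]`, and `total_tokens([2, 8, 8])` is 128.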
- class cyto_dl.nn.vits.predictor.JEPAPredictor(num_patches: int | List[int], spatial_dims: int = 3, input_dim: int | None = 192, emb_dim: int | None = 192, num_layer: int | None = 12, num_head: int | None = 3, learnable_pos_embedding: bool | None = True)
Bases: Module
Predicts target features from context embeddings.
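JEPA-style predictors are commonly wired as in I-JEPA: the encoded context patches and a shared mask token placed at each target position (both tagged with positional embeddings) pass through a transformer, and the outputs at the target slots are read off as predictions. The sketch below assumes this wiring for JEPAPredictor. It also assumes the context embeddings are already `emb_dim` wide (any projection from `input_dim` is omitted), and `jepa_predict` is a hypothetical function that takes the transformer as a callable.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise sum of two equal-length vectors."""
    return [x + y for x, y in zip(a, b)]


def jepa_predict(
    context_emb: Sequence[Vector],
    context_idx: Sequence[int],
    target_idx: Sequence[int],
    pos_emb: Sequence[Vector],
    mask_token: Vector,
    transformer: Callable[[List[Vector]], List[Vector]],
) -> List[Vector]:
    # Context tokens carry the positional embedding of the patch they came from
    context_tokens = [add(e, pos_emb[i]) for e, i in zip(context_emb, context_idx)]
    # One shared mask token per target position, tagged with that position's embedding
    target_tokens = [add(mask_token, pos_emb[i]) for i in target_idx]
    out = transformer(context_tokens + target_tokens)
    # Only the outputs at the target slots are returned as predictions
    return out[len(context_tokens):]


# Usage with an identity "transformer" so the wiring is easy to check
pos_emb = [[float(i), float(i)] for i in range(4)]  # 4 patches, emb_dim = 2
preds = jepa_predict(
    context_emb=[[1.0, 1.0]], context_idx=[0], target_idx=[2, 3],
    pos_emb=pos_emb, mask_token=[0.0, 0.0], transformer=lambda tokens: tokens,
)
```

Because the mask token is shared, the only thing distinguishing one target slot from another at the input is its positional embedding; the transformer must use attention to the context tokens to fill in content.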
- Parameters:
num_patches (int | List[int]) – Number of patches in each dimension. If int, the same number of patches is used for all spatial dimensions
spatial_dims (int) – Number of spatial dimensions
input_dim (int) – Dimension of input
emb_dim (int) – Dimension of embedding
num_layer (int) – Number of transformer layers
num_head (int) – Number of heads in transformer
learnable_pos_embedding (bool) – If True, learnable positional embeddings are used. If False, fixed sine/cosine positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
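For the `learnable_pos_embedding=False` case, the sketch below builds a fixed sine/cosine table for a patch grid. It assumes an MAE-style construction in which `emb_dim` is split evenly across the spatial axes, each axis coordinate receives a standard 1D sinusoidal embedding (frequency base 10000), and the per-axis pieces are concatenated; the exact layout used by cyto_dl may differ. `sincos_1d` and `sincos_nd` are hypothetical names.

```python
import math
from itertools import product
from typing import List


def sincos_1d(positions: List[int], dim: int) -> List[List[float]]:
    """Standard transformer sinusoidal embedding for integer positions."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    omegas = [1.0 / (10000 ** (k / half)) for k in range(half)]
    # First half: sines, second half: cosines
    return [[math.sin(p * w) for w in omegas] + [math.cos(p * w) for w in omegas] for p in positions]


def sincos_nd(grid: List[int], emb_dim: int) -> List[List[float]]:
    """Split emb_dim evenly across spatial axes, embed each coordinate, concatenate."""
    n = len(grid)
    if emb_dim % (2 * n) != 0:
        raise ValueError("emb_dim must be divisible by 2 * number of spatial dims")
    axis_dim = emb_dim // n
    per_axis = [sincos_1d(list(range(size)), axis_dim) for size in grid]
    table = []
    # Row-major patch order (last axis varies fastest); an assumption about flattening
    for coord in product(*(range(size) for size in grid)):
        row: List[float] = []
        for axis, c in enumerate(coord):
            row.extend(per_axis[axis][c])
        table.append(row)
    return table


# Usage: a 2 x 3 x 4 patch grid with the default emb_dim of 192
table = sincos_nd([2, 3, 4], emb_dim=192)
```

Because the table is a deterministic function of the grid, it needs no training, unlike the learnable alternative; with 3 spatial dimensions the default `emb_dim=192` splits into 32 sine and 32 cosine channels per axis.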