cyto_dl.models.contrastive.contrastive module

class cyto_dl.models.contrastive.contrastive.Contrastive(*args, **kwargs)

Bases: BaseModel

Parameters:
  • backbone (nn.Module) – Backbone model

  • task_head (nn.Module) – Task head model

  • anchor_key (str) – Key in the batch dictionary for the anchor image

  • positive_key (str) – Key in the batch dictionary for the positive image

  • target_key (str) – Optional. Key in the batch dictionary for the target; used only for visualization

  • meta_keys (list[str]) – Keys in the batch dictionary whose values are saved to CSV during prediction

  • save_dir (str) – Directory in which visualizations are saved

  • viz_freq (int) – How often visualizations are saved

  • **base_kwargs – Additional arguments passed to BaseModel
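The key parameters name entries of the batch dictionary rather than passing data directly. The sketch below illustrates that lookup and a CSV export of meta_keys under stated assumptions: it does not call cyto_dl, the key names ("image", "image_aug", "label", "cell_id"), the helper functions split_views and meta_to_csv, and the CSV column layout are all hypothetical, and plain Python lists stand in for tensors so it runs without PyTorch.

```python
import csv
import io

# Hypothetical key names mirroring the documented parameters (not cyto_dl defaults).
ANCHOR_KEY = "image"        # anchor_key
POSITIVE_KEY = "image_aug"  # positive_key
TARGET_KEY = "label"        # target_key (optional, visualization only)
META_KEYS = ["cell_id"]     # meta_keys


def split_views(batch, anchor_key=ANCHOR_KEY, positive_key=POSITIVE_KEY, target_key=TARGET_KEY):
    """Return (anchor, positive, target); target is None when the optional key is absent."""
    anchor = batch[anchor_key]
    positive = batch[positive_key]
    target = batch.get(target_key) if target_key is not None else None
    return anchor, positive, target


def meta_to_csv(batch, embeddings, meta_keys=META_KEYS):
    """Hypothetical CSV layout: one row per sample, meta values then embedding values."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    dim = len(embeddings[0])
    writer.writerow(list(meta_keys) + [f"emb_{i}" for i in range(dim)])
    for row_idx, emb in enumerate(embeddings):
        meta = [batch[key][row_idx] for key in meta_keys]
        writer.writerow(meta + list(emb))
    return buf.getvalue()


batch = {
    "image": [[0.1, 0.2], [0.3, 0.4]],
    "image_aug": [[0.1, 0.25], [0.35, 0.4]],
    "cell_id": ["c1", "c2"],
}
anchor, positive, target = split_views(batch)
print(target)  # None: this batch has no "label" entry
print(meta_to_csv(batch, [[1.0, 0.0], [0.0, 1.0]]))
```

Because target_key is documented as optional and used only for visualization, the sketch tolerates its absence with dict.get, while a missing anchor or positive key raises a KeyError.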

forward(x1, x2)
model_step(stage, batch, batch_idx)
plot_classes(predictions, labels)
plot_neighbors(embedding1, embedding2)
predict_step(batch, batch_idx)
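This reference does not state which objective model_step optimizes or what forward(x1, x2) returns. As an assumption only, the sketch below shows one widely used contrastive objective, an InfoNCE-style loss in which each anchor embedding should be most similar to its own positive and less similar to the other positives in the batch, plus a nearest-neighbor match between two embedding sets as one plausible reading of plot_neighbors(embedding1, embedding2). The function names and the temperature value are hypothetical, and plain Python lists replace tensors.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec] if norm > 0 else list(vec)


def cosine_matrix(emb1, emb2):
    """Pairwise cosine similarities: result[i][j] = cos(emb1[i], emb2[j])."""
    a = [l2_normalize(v) for v in emb1]
    b = [l2_normalize(v) for v in emb2]
    return [[sum(x * y for x, y in zip(u, w)) for w in b] for u in a]


def info_nce_loss(anchor_emb, positive_emb, temperature=0.1):
    """Mean InfoNCE loss; positive_emb[i] is row i's positive, other rows are negatives."""
    sims = cosine_matrix(anchor_emb, positive_emb)
    total = 0.0
    for i, row in enumerate(sims):
        logits = [s / temperature for s in row]
        # Log-sum-exp with max subtraction for numerical stability.
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_denom - logits[i]  # cross-entropy with target class i
    return total / len(sims)


def nearest_neighbors(emb1, emb2):
    """For each row of emb1, the index of the most cosine-similar row in emb2."""
    return [max(range(len(row)), key=row.__getitem__) for row in cosine_matrix(emb1, emb2)]


anchors = [[1.0, 0.0], [0.0, 1.0]]
positives = [[0.9, 0.1], [0.1, 0.9]]
print(nearest_neighbors(anchors, positives))  # [0, 1]
print(info_nce_loss(anchors, positives) < info_nce_loss(anchors, positives[::-1]))  # True
```

Reversing the positives pairs each anchor with the wrong view, so the loss rises sharply; a lower temperature sharpens this contrast further.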