cyto_dl.models.contrastive.contrastive module
- class cyto_dl.models.contrastive.contrastive.Contrastive(*args, **kwargs)
Bases: BaseModel
- Parameters:
backbone (nn.Module) – Backbone model
task_head (nn.Module) – Task head model
anchor_key (str) – Key in the batch dictionary for the anchor image
positive_key (str) – Key in the batch dictionary for the positive image
target_key (str, optional) – Key in the batch dictionary for the target; used only for visualization
meta_keys (list[str]) – Keys in the batch dictionary whose values are saved to CSV during prediction
save_dir (str) – Directory in which to save visualizations
viz_freq (int) – Frequency at which to save visualizations
**base_kwargs – Additional keyword arguments passed to BaseModel
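The key parameters name entries in each batch dictionary. The sketch below illustrates that lookup, plus one plausible way of writing the `meta_keys` columns to CSV. It uses only the standard library with made-up key names and stand-in data. The helpers `select_views` and `meta_to_csv` are hypothetical and are not part of `cyto_dl`; the real `Contrastive` class may assemble its prediction CSV differently.

```python
import csv
import io

# Hypothetical batch: real batches would hold image tensors, and the key
# names are whatever you pass as anchor_key, positive_key, meta_keys, etc.
batch = {
    "image": [[0.1, 0.2], [0.3, 0.4]],        # anchor views (stand-in data)
    "image_aug": [[0.1, 0.25], [0.3, 0.35]],  # positive views (stand-in data)
    "cell_id": ["c001", "c002"],              # per-sample metadata
    "plate": ["p1", "p1"],
}

anchor_key = "image"
positive_key = "image_aug"
target_key = None  # optional; only used for visualization
meta_keys = ["cell_id", "plate"]


def select_views(batch, anchor_key, positive_key, target_key=None):
    """Hypothetical helper: pull anchor, positive, and optional target from a batch."""
    anchor = batch[anchor_key]
    positive = batch[positive_key]
    target = batch.get(target_key) if target_key is not None else None
    return anchor, positive, target


def meta_to_csv(batch, meta_keys):
    """Hypothetical helper: one CSV row per sample with the requested metadata columns."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(meta_keys)  # header row
    # zip the per-key lists so each row holds one sample's metadata
    for row in zip(*(batch[k] for k in meta_keys)):
        writer.writerow(row)
    return buf.getvalue()


anchor, positive, target = select_views(batch, anchor_key, positive_key, target_key)
print(meta_to_csv(batch, meta_keys))
```

Because `target_key` is optional and only used for visualization, the sketch returns `None` for the target when no key is given, rather than failing on a missing entry.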