cyto_dl.datamodules.dataframe.grouped_dataframe_datamodule module

class cyto_dl.datamodules.dataframe.grouped_dataframe_datamodule.GroupedDataframeDatamodule(path: UPath | str, transforms: Dict, split_column: str | None = None, columns: Sequence[str] | None = None, split_map: Dict | None = None, just_inference: bool = False, cache_dir: UPath | str | None = None, subsample: Dict | None = None, refresh_subsample: bool = False, seed: int = 42, smartcache_args: Dict | None = None, target_columns: str | None = None, grouping_column: str | None = None, **dataloader_kwargs)

Bases: DataframeDatamodule

A DataframeDatamodule for cases where batches should be grouped by some criterion. Grouping is done with an AlternatingBatchSampler.

Two use cases are currently supported:

  1. Multitask training where ground truths are only available for one task at a time.

  2. Training where batches are grouped by some characteristic of the images.
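The second use case can be illustrated with a minimal, standard-library sketch. It is not the AlternatingBatchSampler implementation; the function name `grouped_batches` and the example labels are hypothetical, and details such as how incomplete batches are handled are assumptions.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence


def grouped_batches(
    group_labels: Sequence[Hashable],
    batch_size: int,
    seed: int = 42,
) -> List[List[int]]:
    """Return batches of row indices in which every index shares one group label.

    Hypothetical helper for illustration only; not part of cyto_dl.
    """
    rng = random.Random(seed)

    # Collect the row indices that belong to each group label.
    by_group: Dict[Hashable, List[int]] = defaultdict(list)
    for idx, label in enumerate(group_labels):
        by_group[label].append(idx)

    batches: List[List[int]] = []
    for indices in by_group.values():
        rng.shuffle(indices)
        # Slice each group's indices into batches; the last one may be short
        # (an assumption, a real sampler might drop or pad it instead).
        for start in range(0, len(indices), batch_size):
            batches.append(indices[start:start + batch_size])

    # Shuffle the batch order so that groups are interleaved over the epoch.
    rng.shuffle(batches)
    return batches


# Hypothetical values of a grouping column, e.g. an imaging magnification.
labels = ["20x", "100x", "20x", "20x", "100x", "100x", "20x"]
batches = grouped_batches(labels, batch_size=2)
```

Every batch produced this way contains rows from a single group, which is the property the grouping_column argument is described as enforcing.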

Parameters:
  • path (Union[Path, str]) – Path to a dataframe file

  • transforms (Dict) – Transforms specifications for each given split.

  • split_column (Optional[str] = None) – Name of a column in the dataset which can be used to create train, val, test splits.

  • columns (Optional[Sequence[str]] = None) – List of columns to load from the dataset, in case it’s a parquet file. If None, load everything.

  • split_map (Optional[Dict] = None) – TODO: document this argument

  • just_inference (bool = False) – Whether this datamodule will be used for just inference (testing/prediction). If so, the splits are ignored and the whole dataset is used.

  • cache_dir (Optional[Union[Path, str]] = None) – Path to a directory in which to store cached transformed inputs, to accelerate batch loading.

  • subsample (Optional[Dict] = None) – Dictionary with a key per split (“train”, “val”, “test”), and the number of samples of each split to use per epoch. If None (default), use all the samples in each split per epoch.

  • refresh_subsample (bool = False) – Whether to refresh the subsample each time the dataloader is called.

  • seed (int = 42) – Random seed.

  • smartcache_args (Optional[Dict] = None) – Arguments to pass to SmartcacheDataset

  • target_columns (Optional[str] = None) – Column names in the CSV corresponding to the ground truth types to alternate between during training.

  • grouping_column (Optional[str] = None) – Name of the column in the CSV corresponding to a factor that should be homogeneous within a batch.

  • dataloader_kwargs – Additional keyword arguments passed to torch.utils.data.DataLoader when instantiating it. These include num_workers, batch_size, shuffle, etc. Note that shuffle is only used for the train dataloader. See the PyTorch docs for more info on these args: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
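To make the role of target_columns more concrete, here is a minimal sketch of alternating between tasks when each row has ground truth for only one of them. It is not the library's implementation: the helper `alternating_task_batches`, the column names `seg` and `nuc_label`, and the file names are all hypothetical, and the round-robin ordering is an assumption about what "alternate" means here.

```python
from itertools import zip_longest
from typing import Dict, List, Optional, Sequence


def alternating_task_batches(
    rows: Sequence[Dict[str, Optional[str]]],
    target_columns: Sequence[str],
    batch_size: int,
) -> List[List[int]]:
    """Build batches whose rows all have ground truth for the same target
    column, taking one batch per target column in turn.

    Hypothetical helper for illustration only; not part of cyto_dl.
    """
    per_task: List[List[List[int]]] = []
    for column in target_columns:
        # A non-empty value in this column means ground truth exists for the task.
        indices = [i for i, row in enumerate(rows) if row.get(column)]
        per_task.append(
            [indices[s:s + batch_size] for s in range(0, len(indices), batch_size)]
        )

    # Round-robin over tasks; tasks with fewer batches simply run out earlier.
    ordered: List[List[int]] = []
    for round_batches in zip_longest(*per_task):
        ordered.extend(b for b in round_batches if b is not None)
    return ordered


# Hypothetical dataframe rows: each has ground truth for exactly one task.
rows = [
    {"raw": "a.tiff", "seg": "a_seg.tiff", "nuc_label": None},
    {"raw": "b.tiff", "seg": None, "nuc_label": "b_label.tiff"},
    {"raw": "c.tiff", "seg": "c_seg.tiff", "nuc_label": None},
    {"raw": "d.tiff", "seg": None, "nuc_label": "d_label.tiff"},
    {"raw": "e.tiff", "seg": "e_seg.tiff", "nuc_label": None},
]
task_batches = alternating_task_batches(rows, ["seg", "nuc_label"], batch_size=2)
```

With these rows the batches alternate between the two target columns until the shorter task is exhausted, so a loss for a given task is only ever computed on rows that carry its ground truth.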

make_dataloader(split)