Source code for cyto_dl.datamodules.dataframe.grouped_dataframe_datamodule

from typing import Dict, Optional, Sequence, Union

from monai.data import DataLoader
from upath import UPath as Path

from .dataframe_datamodule import DataframeDatamodule
from .utils import AlternatingBatchSampler


class GroupedDataframeDatamodule(DataframeDatamodule):
    """A DataframeDatamodule modified for cases where batches should be grouped by some
    criterion, leveraging an AlternatingBatchSampler.

    The two use cases currently supported are:

    1. multitask training where ground truths are only available for one task at a time, and
    2. training where batches are grouped by some characteristic of the images.
    """

    def __init__(
        self,
        path: Union[Path, str],
        transforms: Dict,
        split_column: Optional[str] = None,
        columns: Optional[Sequence[str]] = None,
        split_map: Optional[Dict] = None,
        just_inference: bool = False,
        cache_dir: Optional[Union[Path, str]] = None,
        subsample: Optional[Dict] = None,
        refresh_subsample: bool = False,
        seed: int = 42,
        smartcache_args: Optional[Dict] = None,
        target_columns: Optional[Sequence[str]] = None,
        grouping_column: Optional[str] = None,
        **dataloader_kwargs,
    ):
        """
        Parameters
        ----------
        path: Union[Path, str]
            Path to a dataframe file
        transforms: Dict
            Transforms specifications for each given split.
        split_column: Optional[str] = None
            Name of a column in the dataset which can be used to create train, val, test splits.
        columns: Optional[Sequence[str]] = None
            List of columns to load from the dataset, in case it's a parquet file.
            If None, load everything.
        split_map: Optional[Dict] = None
            Forwarded unchanged to DataframeDatamodule; see its documentation.
        just_inference: bool = False
            Whether this datamodule will be used for just inference (testing/prediction).
            If so, the splits are ignored and the whole dataset is used.
        cache_dir: Optional[Union[Path, str]] = None
            Path to a directory in which to store cached transformed inputs, to accelerate
            batch loading.
        subsample: Optional[Dict] = None
            Dictionary with a key per split ("train", "val", "test"), and the number of samples
            of each split to use per epoch. If `None` (default), use all the samples in each
            split per epoch.
        refresh_subsample: bool = False
            Whether to refresh subsample each time dataloader is called
        seed: int = 42
            random seed
        smartcache_args: Optional[Dict] = None
            Arguments to pass to SmartcacheDataset
        target_columns: Optional[Sequence[str]] = None
            Names of the dataframe columns holding the ground-truth types to alternate between
            during training. Forwarded to AlternatingBatchSampler.
        grouping_column: Optional[str] = None
            Name of the dataframe column holding a factor that should be homogeneous across a
            batch. Forwarded to AlternatingBatchSampler.
        dataloader_kwargs:
            Additional keyword arguments passed to the DataLoader (e.g. `num_workers`).
            `drop_last`, `sampler` and `batch_sampler` are discarded, because batching is
            handled by AlternatingBatchSampler. `batch_size` (default 1) and `shuffle`
            (default True, applied to the train split only) are given to the batch sampler
            rather than to the DataLoader. See the PyTorch docs for more info on these args:
            https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        """
        # Batching is delegated to AlternatingBatchSampler; PyTorch's DataLoader treats
        # `batch_sampler` as mutually exclusive with `sampler` and `drop_last`, so user-supplied
        # values for these are dropped here.
        dataloader_kwargs.pop("drop_last", None)
        dataloader_kwargs.pop("batch_sampler", None)
        dataloader_kwargs.pop("sampler", None)

        super().__init__(
            path=path,
            transforms=transforms,
            split_column=split_column,
            columns=columns,
            split_map=split_map,
            just_inference=just_inference,
            cache_dir=cache_dir,
            subsample=subsample,
            refresh_subsample=refresh_subsample,
            seed=seed,
            smartcache_args=smartcache_args,
            **dataloader_kwargs,
        )
        self.group_column = grouping_column
        self.target_columns = target_columns

    def make_dataloader(self, split):
        """Build a DataLoader for `split` whose batches come from an AlternatingBatchSampler.

        Parameters
        ----------
        split: str
            Name of the split to load. Shuffling is only enabled for "train" (and only if
            `shuffle` was not disabled in the dataloader kwargs).

        Returns
        -------
        DataLoader
            A DataLoader over `self.get_dataset(split)` using the grouped batch sampler.
        """
        kwargs = dict(**self.dataloader_kwargs)

        # `shuffle` and `batch_size` are consumed by the batch sampler: PyTorch's DataLoader
        # rejects them when combined with `batch_sampler`, so they must not be forwarded.
        shuffle = kwargs.pop("shuffle", True) and split == "train"
        # Default to 1 (DataLoader's own default) instead of raising KeyError when the user
        # did not configure a batch size.
        batch_size = kwargs.pop("batch_size", 1)

        subset = self.get_dataset(split)
        batch_sampler = AlternatingBatchSampler(
            subset,
            batch_size=batch_size,
            # NOTE(review): drop_last=True is passed for every split, including
            # val/test/predict; if the sampler honours it, incomplete trailing batches are
            # skipped there as well — confirm this is intended.
            drop_last=True,
            shuffle=shuffle,
            target_columns=self.target_columns,
            grouping_column=self.group_column,
        )
        return DataLoader(dataset=subset, batch_sampler=batch_sampler, **kwargs)