Source code for cyto_dl.datamodules.data_dict

from typing import Callable, List, Optional, Sequence, Union

from monai.data import DataLoader, Dataset, PersistentDataset
from monai.transforms import Compose
from omegaconf import DictConfig, ListConfig, OmegaConf
from upath import UPath as Path


[docs]def make_data_dict_dataloader( data: Sequence[Union[DictConfig, dict]], transforms: Union[Sequence[Callable], Callable], cache_dir: Optional[Union[Path, str]] = None, **dataloader_kwargs, ): """Create a dataloader based on a dictionary of paths to images. Parameters ---------- data: Sequence[Union[DictConfig, dict]] A sequence of dictionaries, each containing a key expected by transforms (usually an image path) transforms: Union[Sequence[Callable], Callable], Transforms to apply to each sample cache_dir: Optional[Union[Path, str]] = None Path to a directory in which to store cached transformed inputs, to accelerate batch loading. dataloader_kwargs: Additional keyword arguments are passed to the torch.utils.data.DataLoader class when instantiating it (aside from `shuffle` which is only used for the train dataloader). Among these args are `num_workers`, `batch_size`, `shuffle`, etc. See the PyTorch docs for more info on these args: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader """ if isinstance(transforms, (list, tuple, ListConfig)): transforms = Compose(transforms) if isinstance(data, (ListConfig, DictConfig)): data = OmegaConf.to_container(data) if cache_dir is not None: dataset = PersistentDataset(data, transform=transforms, cache_dir=cache_dir) else: dataset = Dataset(data, transform=transforms) return DataLoader(dataset, **dataloader_kwargs)