Source code for cyto_dl.datamodules.folder

from typing import Callable, Optional, Sequence, Union

from monai.data import DataLoader, Dataset, PersistentDataset
from monai.transforms import Compose
from omegaconf import ListConfig
from upath import UPath as Path

from cyto_dl.dataframe.transforms.filter import _filter_columns as filter_filenames


def make_folder_dataloader(
    path: Union[Path, str],
    transforms: Union[Sequence[Callable], Callable],
    cache_dir: Optional[Union[Path, str]] = None,
    regex: Optional[str] = None,
    startswith: Optional[str] = None,
    endswith: Optional[str] = None,
    contains: Optional[str] = None,
    excludes: Optional[str] = None,
    **dataloader_kwargs,
):
    """Create a dataloader based on a folder of samples.

    Every direct child of `path` (non-recursive ``glob("*")``) that passes the
    name-based filters becomes one sample. If no transforms are applied, each
    sample is a dictionary with a key "input" containing the corresponding
    path (as a string) and a key "orig_fname" containing the original filename
    with its last suffix removed.

    Entries can be filtered out of the list with name-based rules, using
    `regex`, `startswith`, `endswith`, `contains`, `excludes`. The rules are
    handed to `_filter_columns` together with the *stringified full path* of
    each entry, not only its base name. How the individual rules are combined
    is defined by that helper.

    The retained entries are sorted by path string, so sample indices are
    reproducible across runs and processes.

    Parameters
    ----------
    path: Union[Path, str]
        Path to folder

    transforms: Union[Sequence[Callable], Callable]
        Transforms to apply to each sample. A `list`, `tuple` or `ListConfig`
        of callables is wrapped in a `monai.transforms.Compose`; anything else
        is passed to the dataset as is.

    cache_dir: Optional[Union[Path, str]] = None
        Path to a directory in which to store cached transformed inputs, to
        accelerate batch loading. When given, a `PersistentDataset` is used;
        otherwise a plain `Dataset` is used.

    regex: Optional[str] = None
        A string containing a regular expression to be matched

    startswith: Optional[str] = None
        A substring the matching paths must start with

    endswith: Optional[str] = None
        A substring the matching paths must end with

    contains: Optional[str] = None
        A substring the matching paths must contain

    excludes: Optional[str] = None
        A substring the matching paths must not contain

    dataloader_kwargs:
        Additional keyword arguments, forwarded unchanged to
        `monai.data.DataLoader` when instantiating it. Among these args are
        `num_workers`, `batch_size`, `shuffle`, etc. See the PyTorch docs for
        more info on these args:
        https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

    Returns
    -------
    DataLoader
        A `monai.data.DataLoader` over the dataset built from the folder.
    """
    if isinstance(transforms, (list, tuple, ListConfig)):
        transforms = Compose(transforms)

    candidates = [str(entry) for entry in Path(path).glob("*")]

    # The order in which `glob` yields entries is not guaranteed (it depends
    # on the filesystem/backend). Sorting after filtering gives every run and
    # every worker/rank the same index -> sample mapping.
    filenames = sorted(
        filter_filenames(
            candidates,
            regex,
            startswith,
            endswith,
            contains,
            excludes,
        )
    )

    data = [
        {"input": filename, "orig_fname": Path(filename).stem}
        for filename in filenames
    ]

    if cache_dir is not None:
        dataset = PersistentDataset(data, transform=transforms, cache_dir=cache_dir)
    else:
        dataset = Dataset(data, transform=transforms)

    return DataLoader(dataset, **dataloader_kwargs)