import random
import re
from itertools import chain, repeat
from typing import Iterator, List, Optional, Sequence

import numpy as np
import pyarrow as pa

try:
    import modin.pandas as pd
except ModuleNotFoundError:
    import pandas as pd

from monai.data import Dataset, PersistentDataset, SmartCacheDataset
from monai.transforms import Compose, Transform
from omegaconf import DictConfig, ListConfig
from torch.utils.data import BatchSampler, Sampler, Subset, SubsetRandomSampler
from upath import UPath as Path

from cyto_dl.dataframe import read_dataframe


class RemoveNaNKeysd(Transform):
"""Transform to remove 'nan' keys from data dictionary.
When combined with adding `allow_missing_keys=True` to transforms and the alternating batch
sampler, this allows multi-task training when only one target is available at a time.
"""

    def __init__(self):
        super().__init__()

def __call__(self, img):
new_data = {k: v for k, v in img.items() if not pd.isna(v)}
return new_data
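

# Illustrative sketch (not part of the library): applying RemoveNaNKeysd to a record in which
# one target is missing drops the NaN entry, so downstream transforms built with
# `allow_missing_keys=True` simply skip the absent key.
# >>> RemoveNaNKeysd()({"raw": "img.tiff", "seg": "seg.tiff", "struct": float("nan")})
# {'raw': 'img.tiff', 'seg': 'seg.tiff'}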


# randomly switch between generating batches from each sampler
class AlternatingBatchSampler(BatchSampler):
"""Subclass of pytorch's `BatchSampler` that alternates between sampling from mutually
exclusive columns of a dataframe dataset."""

    def __init__(
self,
subset: Subset,
        target_columns: Optional[Sequence[str]] = None,
        grouping_column: Optional[str] = None,
batch_size: int = 1,
drop_last: bool = False,
shuffle: bool = False,
sampler: Sampler = SubsetRandomSampler,
):
"""
Parameters
----------
subset: Subset
Subset of monai dataset wrapping a dataframe
target_columns: Sequence[str]
names of columns in `subset` dataframe representing types of ground truth images to alternate between
batch_size:int
Size of batch
drop_last:bool=False
Whether to drop last incomplete batch
shuffle:bool=False
Whether to randomly select between columns in `target_columns`. If False, batches will follow the order of
`target_columns`
sampler:Sampler=SubsetRandomSampler
Sampler to sample from each column in `target_columns`
"""
assert (
grouping_column is not None or target_columns is not None
), "Either target column or grouping column must be provided"
super().__init__(None, batch_size, drop_last)
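        # note: the parent class's `sampler` argument is never used here because __iter__ and
        # __len__ are both overridden below, so passing None is safe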
        # pytorch Subsets consist of the original dataset plus a list of `indices`. When given
        # an index `i` in __getitem__, a Subset returns `dataset[indices[i]]`. Since our samplers
        # index into the subset rather than into the original dataframe, we take the subset's rows
        # and reset the index so positions run from 0 to len(subset) - 1.
        # Attribute access order is subset.dataset (monai dataset) -> .data (_DataframeWrapper) -> .dataframe
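        # e.g. subset.indices == [4, 7, 9]: sampler index 1 -> subset[1] -> original dataset row 7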
idx_to_take = list(subset.indices)
subset_df = subset.dataset.data.dataframe.take(idx_to_take).to_pandas().reset_index()
samplers = []
if target_columns is not None:
for name in target_columns:
# returns an index into dataset.indices where head column is not empty
head_indices = subset_df.index[~subset_df[name].isna()].values
if len(head_indices) == 0:
raise ValueError(
f"Dataset must contain examples of head {name}. Please increase the value of subsample."
)
samplers.append(sampler(head_indices))
else:
grouping_options = set(
subset.dataset.data.dataframe[grouping_column].unique().to_pylist()
)
seen_keys = set()
for group_key, group in subset_df.groupby(grouping_column):
samplers.append(sampler(group.index.values))
seen_keys.add(group_key)
unseen_keys = grouping_options - seen_keys
if unseen_keys:
raise ValueError(
f"Dataset must contain examples of groups {unseen_keys}."
"Please increase the value of subsample."
)
self.samplers = samplers
self.shuffle = shuffle
self._sampler_generator()

    def _sampler_generator(self):
self.sampler_iterators = [iter(s) for s in self.samplers]
        # For now, draw an equal number of batches from every sampler. In the worst-case shuffle
        # the order could be [0, 0, ..., 0, 1, 1, ..., 1]; if sampler 0 were asked for more
        # batches than it holds, StopIteration would be raised only halfway through the epoch.
        # Capping every sampler at the smallest sampler's batch count prevents that by never
        # fully exhausting any sampler.
samples_per_sampler = len(self) // len(self.samplers)
if samples_per_sampler <= 0:
raise ValueError("Insufficient examples per task head. Please decrease batch size.")
interleaved_sampler_order = repeat(range(len(self.samplers)), samples_per_sampler)
interleaved_sampler_order = chain.from_iterable(interleaved_sampler_order)
interleaved_sampler_order = list(interleaved_sampler_order)
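        # e.g. with two samplers and samples_per_sampler == 3 this yields [0, 1, 0, 1, 0, 1]
        # (optionally shuffled below)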
if self.shuffle:
random.shuffle(interleaved_sampler_order)
self.sampler_order = interleaved_sampler_order

    def __iter__(self) -> Iterator[List[int]]:
for sampler_ix in self.sampler_order:
try:
yield [next(self.sampler_iterators[sampler_ix]) for _ in range(self.batch_size)]
except StopIteration:
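                # if a sampler iterator is exhausted early, rebuild all iterators and retry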
self._sampler_generator()
yield [next(self.sampler_iterators[sampler_ix]) for _ in range(self.batch_size)]

    def __len__(self) -> int:
min_num_samples = min(len(sampler) for sampler in self.samplers)
min_num_batches = min_num_samples // self.batch_size
return min_num_batches * len(self.samplers)
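

# Minimal usage sketch (assumes a monai dataset built from a dataframe with `seg` and `struct`
# target columns; the names and variables here are illustrative, not part of the library):
# >>> subset = Subset(dataset, train_indices)
# >>> batch_sampler = AlternatingBatchSampler(
# ...     subset, target_columns=["seg", "struct"], batch_size=4, shuffle=True
# ... )
# >>> loader = torch.utils.data.DataLoader(subset, batch_sampler=batch_sampler)
# Each yielded batch then contains only rows where a single one of the two targets is present.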


def get_canonical_split_name(split):
for canon in ("train", "val", "test", "predict"):
if split.startswith(canon) or canon.startswith(split):
return canon
return None
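

# e.g. get_canonical_split_name("validation") -> "val", get_canonical_split_name("testing") -> "test",
# and an unrecognized name such as "holdout" -> None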


def get_dataset(dataframe, transform, split, cache_dir=None, smartcache_args=None):
data = _DataframeWrapper(dataframe)
if smartcache_args is not None and split in ("train", "val"):
cache_rate = smartcache_args.get("cache_rate", 0.1)
if cache_rate * len(data) >= 1:
print(f"Initializing SmartCacheDataset for {split}")
return SmartCacheDataset(
data,
transform=transform,
cache_rate=cache_rate,
replace_rate=smartcache_args.get("replace_rate", 0.1),
num_init_workers=smartcache_args.get("num_init_workers", 4),
num_replace_workers=smartcache_args.get("num_replace_workers", 4),
)
else:
print(
f"Cache rate {cache_rate} too low for {split} dataset with size {len(data)}, using Dataset."
)
if cache_dir is not None and split in ("train", "val"):
print(f"Initializing PersistentDataset for {split} at {cache_dir}")
return PersistentDataset(data, transform=transform, cache_dir=cache_dir)
return Dataset(data, transform=transform)
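

# Illustrative call (variable names are placeholders, not part of the library); the
# `smartcache_args` keys shown are the ones read above, with example values:
# >>> ds = get_dataset(
# ...     df, transform=train_transforms, split="train", cache_dir=None,
# ...     smartcache_args={"cache_rate": 0.25, "replace_rate": 0.1,
# ...                      "num_init_workers": 4, "num_replace_workers": 4},
# ... )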


def make_single_dataframe_splits(
dataframe_path,
transforms,
split_column,
columns=None,
just_inference=False,
split_map=None,
cache_dir=None,
smartcache_args=None,
):
dataframe = read_dataframe(dataframe_path, columns)
dataframe[split_column] = dataframe[split_column].astype(np.dtype("O"))
if not just_inference:
assert dataframe.dtypes[split_column] == np.dtype("O")
if split_map is not None:
dataframe[split_column] = dataframe[split_column].replace(split_map)
split_names = dataframe[split_column].unique().tolist()
if not just_inference:
assert set(split_names).issubset(
{"train", "training", "valid", "val", "validation", "test", "testing"}
)
if split_column != "split":
dataframe["split"] = dataframe[split_column].apply(get_canonical_split_name)
datasets = {}
if not just_inference:
for split in ("train", "val", "test"):
if cache_dir is not None:
_split_cache = Path(cache_dir) / split
_split_cache.mkdir(exist_ok=True, parents=True)
else:
_split_cache = None
datasets[split] = get_dataset(
dataframe.loc[dataframe["split"].str.startswith(split)],
transforms[split],
split,
_split_cache,
smartcache_args=smartcache_args,
)
datasets["predict"] = get_dataset(
dataframe,
transform=transforms["predict"],
split="predict",
cache_dir=cache_dir,
smartcache_args=smartcache_args,
)
return datasets
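

# Minimal sketch (paths and column names are illustrative; `transforms` is a dict keyed by split
# name): a single manifest with a split column yields train/val/test/predict datasets in one call.
# >>> datasets = make_single_dataframe_splits(
# ...     "manifest.parquet", transforms, split_column="split", cache_dir="./cache"
# ... )
# >>> datasets.keys()
# dict_keys(['train', 'val', 'test', 'predict'])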


def make_multiple_dataframe_splits(
split_path,
transforms,
columns=None,
just_inference=False,
cache_dir=None,
smartcache_args=None,
):
split_path = Path(split_path)
datasets = {}
predict_df = []
for fpath in chain(split_path.glob("*.csv"), split_path.glob("*.parquet")):
split = re.findall(r"(.*)\.(?:csv|parquet)", fpath.name)[0]
split = get_canonical_split_name(split)
if split is None:
print(f"Skipping {fpath} as it does not match any known split name.")
continue
dataframe = read_dataframe(fpath, required_columns=columns)
dataframe["split"] = split
if cache_dir is not None:
_split_cache = Path(cache_dir) / split
_split_cache.mkdir(exist_ok=True, parents=True)
else:
_split_cache = None
if not just_inference:
datasets[split] = get_dataset(
dataframe, transforms[split], split, _split_cache, smartcache_args=smartcache_args
)
predict_df.append(dataframe.copy())
if len(predict_df) == 0:
raise ValueError(f"No valid dataframe files found in {split_path}")
predict_df = pd.concat(predict_df)
datasets["predict"] = get_dataset(
predict_df,
transform=transforms["predict"],
split="predict",
cache_dir=cache_dir,
smartcache_args=smartcache_args,
)
return datasets
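

# Minimal sketch (illustrative layout): `split_path` is expected to contain one csv/parquet file
# per split, named so that its stem maps onto a canonical split name, e.g.
#     splits/
#         train.csv
#         val.csv
#         test.csv
# >>> datasets = make_multiple_dataframe_splits("splits/", transforms, cache_dir=None)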


class _DataframeWrapper:
    """Class to wrap a pandas DataFrame in a pytorch Dataset. In practice, at AICS this is used
    to wrap manifest dataframes whose rows point to the image files corresponding to a cell.

    Parameters
    ----------
    dataframe: pd.DataFrame
        The dataframe whose rows point to or contain the data to be loaded
    """

def __init__(self, dataframe):
self.dataframe = pa.Table.from_pandas(dataframe)

    def __len__(self):
return len(self.dataframe)

    def __getitem__(self, idx):
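        # take([idx]) returns a one-row pyarrow Table; to_pydict() maps each column name to a
        # single-element list, and pop() unwraps that element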
return {k: v.pop() for k, v in self.dataframe.take([idx]).to_pydict().items()}
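

# Illustrative sketch (assuming plain pandas; column names are only an example): wrapping a
# manifest dataframe gives a map-style dataset of per-row dictionaries.
# >>> wrapper = _DataframeWrapper(pd.DataFrame({"raw": ["a.tiff", "b.tiff"], "seg": ["a_seg.tiff", "b_seg.tiff"]}))
# >>> len(wrapper)
# 2
# >>> wrapper[0]
# {'raw': 'a.tiff', 'seg': 'a_seg.tiff'}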