from itertools import chain
from pathlib import Path
from typing import Dict, Optional, Union
import dask
import numpy as np
import pandas as pd
from bioio import BioImage
from dask.diagnostics import ProgressBar
from lightning import LightningDataModule
from monai.data import DataLoader
from monai.data.dataset import CacheDataset, Dataset, SmartCacheDataset
from monai.transforms import Compose
from sklearn.model_selection import train_test_split
class SmartcacheDatamodule(LightningDataModule):
    """Datamodule for large CZI datasets that don't fit in memory.

    Training data is wrapped in a MONAI ``SmartCacheDataset`` so that only a
    fraction of it is held in memory and gradually replaced during training.
    """
def __init__(
self,
csv_path: Optional[Union[Path, str]] = None,
        transforms: Optional[Dict[str, Compose]] = None,
img_data: Optional[Union[Path, str]] = None,
n_val: int = 20,
pct_val: float = 0.1,
img_path_column: str = "raw",
channel_column: str = "ch",
spatial_dims: int = 3,
num_neighbors: int = 0,
num_workers: int = 4,
cache_rate: float = 0.5,
replace_rate: float = 0.1,
**kwargs,
):
"""
Parameters
----------
        csv_path: Optional[Union[Path, str]]
            path to a csv with an image path in `img_path_column` and a channel in `channel_column`
        transforms: Dict[str, Compose]
            dictionary of MONAI transform pipelines keyed by split ("train", "valid",
            "test", "predict"). Each should start with a transform that uses bioio
            for image reading
        img_data: Optional[Union[Path, str]]
            directory containing `train_img_data.csv` and `val_img_data.csv` (as written
            by `setup`) that enumerate scenes and timepoints for each image in `csv_path`
        n_val: int
            maximum number of validation images. The minimum of `n_val` and
            `pct_val * n_images` is used.
        pct_val: float
            fraction of images to use for validation. The minimum of `n_val` and
            `pct_val * n_images` is used.
img_path_column: str
column in csv_path that contains the path to the image
channel_column: str
column in csv_path that contains the channel to use
spatial_dims: int
number of spatial dimensions in the image
num_neighbors: int
number of neighboring timepoints to use
        num_workers: int
            number of workers to use for loading data. Must be specified here (rather
            than in `kwargs`) so that replacement workers can be scheduled for the
            cached data
cache_rate: float
percentage of data to cache
replace_rate: float
percentage of data to replace
kwargs:
additional arguments to pass to DataLoader
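
        Examples
        --------
        A minimal construction sketch (the csv layout and transform pipelines
        here are illustrative assumptions, not prescribed by this class):

            >>> # `train_compose` / `val_compose` are hypothetical Compose
            >>> # pipelines that start with a bioio-based reading transform
            >>> dm = SmartcacheDatamodule(
            ...     csv_path="images.csv",
            ...     transforms={"train": train_compose, "valid": val_compose},
            ...     num_workers=4,
            ... )
            >>> dm.setup(stage="fit")

        After the first run, the enumerated scenes/timepoints are written to a
        `loaded_data` directory next to the csv and can be reused by passing
        that directory as `img_data`.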
"""
super().__init__()
        self.img_data = {}
        if isinstance(img_data, (str, Path)):
            # img_data points at a directory of pre-enumerated loading args;
            # read them instead of re-enumerating in `setup`
self.img_data["train"] = [
row._asdict()
for row in pd.read_csv(Path(img_data) / "train_img_data.csv").itertuples()
]
self.img_data["val"] = [
row._asdict()
for row in pd.read_csv(Path(img_data) / "val_img_data.csv").itertuples()
]
elif csv_path is not None:
self.csv_path = Path(csv_path)
(self.csv_path.parents[0] / "loaded_data").mkdir(exist_ok=True, parents=True)
self.df = pd.read_csv(csv_path)
else:
raise ValueError("csv_path or img_data must be specified")
self.num_workers = num_workers
self.kwargs = kwargs
self.n_val = n_val
self.pct_val = pct_val
self.datasets = {}
self.img_path_column = img_path_column
self.channel_column = channel_column
self.spatial_dims = spatial_dims
self.transforms = transforms
self.num_neighbors = num_neighbors
self.cache_rate = cache_rate
self.replace_rate = replace_rate
def _get_scenes(self, img):
"""Get the number of scenes in an image."""
return img.scenes
def _get_timepoints(self, img):
"""Get the number of timepoints in an image."""
timepoints = list(range(img.dims.T))
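        # when using temporal neighbors, drop the trailing timepoints so that
        # every remaining timepoint has `num_neighbors` frames after it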
if self.num_neighbors > 0:
return timepoints[: -self.num_neighbors]
return timepoints
@dask.delayed
def _get_file_args(self, row):
row = row._asdict()
img = BioImage(row[self.img_path_column])
scenes = self._get_scenes(img)
timepoints = self._get_timepoints(img)
img_data = []
use_neighbors = self.num_neighbors > 0
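        # build one loading-argument dict per (scene, timepoint) pair; these
        # dicts are consumed downstream by a bioio-based reading transform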
for scene in scenes:
for timepoint in timepoints:
img_data.append(
{
"dimension_order_out": (
"ZYX"[-self.spatial_dims :]
if not use_neighbors
else "T" + "ZYX"[-self.spatial_dims :]
),
"C": row[self.channel_column],
"scene": scene,
"T": (
timepoint
if not use_neighbors
else [timepoint + i for i in range(self.num_neighbors + 1)]
),
"original_path": row[self.img_path_column],
}
)
return img_data
    def get_per_file_args(self, df):
"""Parallelize getting the image loading arguments enumerating all
timepoints/channels/scenes for each file in the dataframe."""
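        # each element of the returned list is a dict of reading arguments
        # (dimension order, channel, scene, timepoint(s), original path)
        # produced by `_get_file_args`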
with ProgressBar():
img_data = dask.compute(*[self._get_file_args(row) for row in df.itertuples()])
img_data = list(chain.from_iterable(img_data))
return img_data
    def prepare_data(self):
pass
    def setup(self, stage=None):
if stage == "fit":
if "train" not in self.img_data or "val" not in self.img_data:
                # enumerate loading args for every file in the csv and split train/val
image_data = self.get_per_file_args(self.df)
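                # validation size is min(n_val, pct_val * n_images), floored at 1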
val_size = np.min([self.n_val, int(len(image_data) * self.pct_val)])
val_size = np.max([val_size, 1])
self.img_data["train"], self.img_data["val"] = train_test_split(
image_data, test_size=val_size
)
print("Train images:", len(self.img_data["train"]))
print("Val images:", len(self.img_data["val"]))
pd.DataFrame(self.img_data["train"]).to_csv(
f"{self.csv_path.parents[0]}/loaded_data/train_img_data.csv",
index=False,
)
pd.DataFrame(self.img_data["val"]).to_csv(
f"{self.csv_path.parents[0]}/loaded_data/val_img_data.csv", index=False
)
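            # training uses SmartCacheDataset: only `cache_rate` of the data is
            # cached in memory, and `replace_rate` of the cache is replaced with
            # new samples as training progresses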
self.datasets["train"] = SmartCacheDataset(
self.img_data["train"],
transform=self.transforms["train"],
cache_rate=self.cache_rate,
num_replace_workers=self.num_workers,
num_init_workers=self.num_workers,
replace_rate=self.replace_rate,
)
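            # the smaller validation set is fully cached up front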
self.datasets["val"] = CacheDataset(
self.img_data["val"],
transform=self.transforms["valid"],
cache_rate=1.0,
num_workers=self.num_workers,
)
elif stage in ("test", "predict"):
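            # test/predict enumerate every scene/timepoint in the csv and use a
            # plain, uncached Dataset (requires construction with `csv_path`)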
self.img_data[stage] = self.get_per_file_args(self.df)
self.datasets[stage] = Dataset(self.img_data[stage], transform=self.transforms[stage])
    def make_dataloader(self, split):
        # SmartCacheDataset does not support persistent workers, so disable
        # them for the cached train/val loaders
        self.kwargs["persistent_workers"] = split not in ("train", "val")
        # num_workers is passed explicitly below; drop it from kwargs so it
        # isn't supplied twice
        if "num_workers" in self.kwargs:
            del self.kwargs["num_workers"]
return DataLoader(
self.datasets[split],
num_workers=self.num_workers,
**self.kwargs,
)
    def train_dataloader(self):
return self.make_dataloader("train")
    def val_dataloader(self):
return self.make_dataloader("val")
    def test_dataloader(self):
return self.make_dataloader("test")
    def predict_dataloader(self):
return self.make_dataloader("predict")