import numpy as np
import pandas as pd
import torch
from aicsimageio import AICSImage
from fnet.data.fnetdataset import FnetDataset
class MultiChTiffDataset(FnetDataset):
    """
    Dataset for multi-channel tiff files.

    Each row of the underlying dataframe describes one image through three
    columns:

    - ``path_tiff``: path handed to ``AICSImage``.
    - ``channel_signal``: channel index/indices used as the signal.
    - ``channel_target``: channel index/indices used as the target. A
      missing value (NaN/None) marks a row without a target.

    A channel cell may be a single number, a string in
    "[ind_1, ind_2, ..., ind_n]" format, or a sequence of indices.
    """

    def __init__(
        self,
        dataframe: pd.DataFrame = None,
        path_csv: str = None,
        transform_signal=None,
        transform_target=None,
    ):
        """
        Parameters
        ----------
        dataframe: pd.DataFrame
            Forwarded to ``FnetDataset``.
        path_csv: str
            Forwarded to ``FnetDataset``.
        transform_signal:
            Iterable of callables applied in order to the signal array, or
            None.
        transform_target:
            Iterable of callables applied in order to the target array, or
            None.

        Raises
        ------
        AssertionError
            If a required column is absent from the dataframe.
        ValueError
            If a ``channel_signal`` cell is missing.
        """
        super().__init__(dataframe, path_csv, transform_signal, transform_target)

        # Validate before the columns are touched so that a missing column is
        # reported by this check instead of a bare KeyError further down.
        missing = [
            col
            for col in ("path_tiff", "channel_signal", "channel_target")
            if col not in self.df.columns
        ]
        assert not missing, f"dataframe is missing required columns: {missing}"

        # Cells are parsed one at a time rather than dispatching on the type
        # of the first row: this works for empty frames, for frames whose
        # index does not contain the label 0, and for columns with gaps.
        signal = [self._parse_channel_spec(cell) for cell in self.df["channel_signal"]]
        # _parse_channel_spec returns the float NaN only for missing cells.
        if any(isinstance(spec, float) for spec in signal):
            raise ValueError("channel_signal must be specified for every row")
        self.df["channel_signal"] = signal

        # Missing targets are kept as NaN so __getitem__ can detect them.
        self.df["channel_target"] = [
            self._parse_channel_spec(cell) for cell in self.df["channel_target"]
        ]

    @staticmethod
    def _parse_channel_spec(spec):
        """
        Convert one channel cell into integer channel indices.

        Parameters
        ----------
        spec:
            A string in "[ind_1, ind_2, ..., ind_n]" format, a sequence of
            indices, a single number, or a missing value (None/NaN).

        Returns
        -------
        Integer np.ndarray for strings and sequences, a one-element list for
        a single number, or np.nan for a missing value.
        """
        if isinstance(spec, str):
            text = spec.strip()
            # Drop the surrounding brackets only when they are really there.
            if text.startswith("[") and text.endswith("]"):
                text = text[1:-1]
            return np.fromstring(text, sep=", ").astype(int)
        if np.ndim(spec) > 0:
            # Already a sequence of indices (e.g. a previously parsed frame).
            return np.asarray(spec).astype(int)
        if pd.isna(spec):
            return np.nan
        return [int(spec)]

    def __getitem__(self, index):
        """
        Parameters
        ----------
        index: integer
            Positional row index into the dataframe.

        Returns
        -------
        Tuple of float32 torch.Tensor, each C by <spatial dimensions>:
        ``(signal,)`` when the row has no target, ``(signal, target)``
        otherwise.
        """
        element = self.df.iloc[index, :]
        target_channels = element["channel_target"]
        has_target = not np.any(np.isnan(target_channels))

        # aicsimageio.imread loads as STCZYX, so we load only CZYX
        with AICSImage(element["path_tiff"]) as img:
            im_tmp = img.get_image_data("CZYX", S=0, T=0)

        # Indexing with a list/array of channels keeps the channel axis first.
        im_out = [im_tmp[element["channel_signal"]]]
        if has_target:
            im_out.append(im_tmp[target_channels])

        if self.transform_signal is not None:
            for t in self.transform_signal:
                im_out[0] = t(im_out[0])

        if has_target and self.transform_target is not None:
            for t in self.transform_target:
                im_out[1] = t(im_out[1])

        # Convert straight to float32; astype always copies, so the result
        # owns its memory without a float64 intermediate.
        return tuple(torch.from_numpy(im.astype(np.float32)) for im in im_out)

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)