from fnet.utils.general_utils import to_objects, whats_my_name
from typing import List, Optional, Union
import pandas as pd
import torch.utils.data
def _to_str_list(olist: List) -> Optional[List[str]]:
"""Turns a list of objects into a list of the objects' string
representations.
"""
if olist is None:
return None
return [whats_my_name(o) for o in olist]
class _LocIndexer:
"""'Loc' indexer of objects with a 'df' (DataFrame) attribute."""
def __init__(self, super_obj):
assert isinstance(super_obj.df, pd.DataFrame)
self.super_obj = super_obj
def __getitem__(self, idx):
idx_trans = self.super_obj.df.index.get_loc(idx)
return self.super_obj[idx_trans]
class _iLocIndexer:
"""'iLoc' indexer of objects with a 'df' (DataFrame) attribute."""
def __init__(self, super_obj):
assert isinstance(super_obj.df, pd.DataFrame)
self.super_obj = super_obj
def __getitem__(self, idx):
return self.super_obj[idx]
[docs]class FnetDataset(torch.utils.data.Dataset):
"""Abstract class for fnet datasets.
Parameters
----------
dataframe
DataFrame where rows are dataset elements. Overrides path_csv.
path_csv
Path to csv from which to create DataFrame.
transform_signal
List of transforms to apply to signal image.
transform_target
List of transforms to apply to target image.
"""
def __init__(
self,
dataframe: Optional[pd.DataFrame] = None,
path_csv: Optional[str] = None,
transform_signal: Optional[list] = None,
transform_target: Optional[list] = None,
):
self.path_csv = None
if dataframe is not None:
self.df = dataframe
else:
self.path_csv = path_csv
self.df = pd.read_csv(self.path_csv)
self.transform_signal = to_objects(transform_signal)
self.transform_target = to_objects(transform_target)
self._metadata = None
self.loc = _LocIndexer(self)
self.iloc = _iLocIndexer(self)
@property
def metadata(self) -> dict:
"""Returns metadata about the dataset."""
if self._metadata is not None:
return self._metadata
self._metadata = {}
if self.path_csv is not None:
self._metadata["path_csv"] = self.path_csv
self._metadata["transform_signal"] = _to_str_list(self.transform_signal)
self._metadata["transform_target"] = _to_str_list(self.transform_target)
return self._metadata