import os
import numpy as np
import pandas as pd
import tifffile
import torch
from fnet.data import TiffDataset
from fnet.utils.general_utils import add_augmentations
def DummyFnetDataset(train: bool = False) -> TiffDataset:
    """Return a small TiffDataset built from the bundled test images.

    The dataframe has one row pairing ``data/EM_low.tif`` (signal) with
    ``data/MBP_low.tif`` (target). The paths are relative, so they are
    resolved against the current working directory. The dataframe index is
    named ``"arbitrary"``.

    Parameters
    ----------
    train
        When False (the default), the dataframe is passed through
        ``add_augmentations`` before the dataset is built. When True, the
        single un-augmented row is used.

    Returns
    -------
    TiffDataset
        Dataset wrapping the constructed dataframe.
    """
    df = pd.DataFrame(
        {
            "path_signal": [os.path.join("data", "EM_low.tif")],
            "path_target": [os.path.join("data", "MBP_low.tif")],
        }
    ).rename_axis("arbitrary")
    # NOTE(review): augmentations are added only when ``train`` is False,
    # which reads as inverted. Kept as-is because callers may depend on the
    # resulting dataset contents -- confirm intent.
    if not train:
        df = add_augmentations(df)
    return TiffDataset(dataframe=df)
class _CustomDataset:
    """Custom, non-FnetDataset."""
    def __init__(self, df: pd.DataFrame):
        self._df = df
    def __len__(self):
        return len(self._df)
    def __getitem__(self, idx):
        loc = self._df.index[idx]
        sig = torch.from_numpy(
            tifffile.imread(self._df.loc[loc, "path_signal"])[np.newaxis,]
        )
        tar = torch.from_numpy(
            tifffile.imread(self._df.loc[loc, "path_target"])[np.newaxis,]
        )
        return (sig, tar)
def DummyCustomFnetDataset(train: bool = False) -> _CustomDataset:
    """Return a small non-FnetDataset built from the bundled test images.

    The dataframe has one row pairing ``data/EM_low.tif`` (signal) with
    ``data/MBP_low.tif`` (target). The paths are relative, so they are
    resolved against the current working directory. Unlike
    ``DummyFnetDataset``, the index is left unnamed.

    Parameters
    ----------
    train
        When False (the default), the dataframe is passed through
        ``add_augmentations`` before the dataset is built. When True, the
        single un-augmented row is used.

    Returns
    -------
    _CustomDataset
        Dataset wrapping the constructed dataframe (not a TiffDataset).
    """
    df = pd.DataFrame(
        {
            "path_signal": [os.path.join("data", "EM_low.tif")],
            "path_target": [os.path.join("data", "MBP_low.tif")],
        }
    )
    # NOTE(review): augmentations are added only when ``train`` is False,
    # mirroring DummyFnetDataset. This reads as inverted -- confirm intent.
    if not train:
        df = add_augmentations(df)
    return _CustomDataset(df)