"""fnet.utils.split_dataset: split a dataset CSV into train.csv and test.csv."""

import argparse
import os
import numpy as np
import pandas as pd


def int_or_float(x):
    """Parse a ``--train_size`` token as a row count or a fraction of rows.

    Tokens that parse with ``int`` are returned as a non-negative ``int``
    (an absolute number of rows). Anything else is parsed with ``float`` and
    must lie in the closed interval [0.0, 1.0] (a fraction of all rows).
    Note that "1" therefore means one row, while "1.0" means every row.

    Args:
        x: Command-line token, e.g. "100" or "0.8".

    Returns:
        The parsed ``int`` or ``float``.

    Raises:
        ValueError: If ``x`` is not numeric, is a negative integer, or is a
            float outside [0.0, 1.0] (NaN included). Raising ``ValueError``
            (rather than using ``assert``) lets argparse report a normal
            usage error when this function is used as ``type=``, and keeps
            the check active under ``python -O``.
    """
    try:
        val = int(x)
    except ValueError:
        val = float(x)
        # Phrased as ``not (lo <= v <= hi)`` so that NaN is rejected too.
        if not 0.0 <= val <= 1.0:
            raise ValueError(
                "fraction must be in [0.0, 1.0], got {!r}".format(val)
            )
        return val
    if val < 0:
        raise ValueError("row count must be non-negative, got {!r}".format(val))
    return val
def _split_train_test(df_all, train_size):
    """Split ``df_all`` positionally into ``(df_train, df_test)``.

    The first ``idx_split`` rows become the training set and the remaining
    rows the test set. ``idx_split`` is ``train_size`` itself for an ``int``,
    or ``round(len(df_all) * train_size)`` for a ``float`` fraction. A split
    point of 0 yields an empty training frame that still carries all columns.

    Raises:
        TypeError: If ``train_size`` is neither ``int`` nor ``float``.
    """
    if isinstance(train_size, int):
        idx_split = train_size
    elif isinstance(train_size, float):
        idx_split = round(len(df_all) * train_size)
    else:
        raise TypeError(
            "train_size must be int or float, got {}".format(
                type(train_size).__name__
            )
        )
    # .iloc makes the positional intent explicit regardless of the index.
    return df_all.iloc[:idx_split], df_all.iloc[idx_split:]


def main(args=None):
    """Command-line entry point: split a dataset CSV into train/test CSVs.

    Reads ``src_csv`` and, unless ``--no_shuffle`` is given, shuffles its
    rows with a ``RandomState`` seeded by ``--seed``. It then writes
    ``<dst_dir>/<name>/train.csv`` and ``<dst_dir>/<name>/test.csv``, where
    ``<name>`` is the source file name up to its first ".". If both output
    files already exist, it returns early without recomputing anything.

    Args:
        args: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("src_csv", help="path to dataset CSV")
    parser.add_argument("dst_dir", help="destination directory of dataset split")
    parser.add_argument(
        "--train_size",
        type=int_or_float,
        default=0.8,
        help="training set size as int or fraction of total dataset size",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument(
        "--no_shuffle",
        action="store_true",
        help="do not shuffle rows before splitting",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="verbose")
    opts = parser.parse_args(args)
    vprint = print if opts.verbose else lambda *a, **kw: None

    # NOTE(review): split(".")[0] truncates at the FIRST dot, so e.g.
    # "cells.v1.csv" and "cells.v2.csv" both map to "cells", and the second
    # run silently reuses the first run's split via the early return below.
    # os.path.splitext would avoid this but would move existing outputs.
    name = os.path.basename(opts.src_csv).split(".")[0]
    path_store_split = os.path.join(opts.dst_dir, name)
    path_train_csv = os.path.join(path_store_split, "train.csv")
    path_test_csv = os.path.join(path_store_split, "test.csv")
    if os.path.exists(path_train_csv) and os.path.exists(path_test_csv):
        vprint("Using existing train/test split.")
        return

    rng = np.random.RandomState(opts.seed)
    df_all = pd.read_csv(opts.src_csv)
    if not opts.no_shuffle:
        df_all = df_all.sample(frac=1.0, random_state=rng).reset_index(drop=True)
    df_train, df_test = _split_train_test(df_all, opts.train_size)
    vprint("train/test sizes: {:d}/{:d}".format(len(df_train), len(df_test)))

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(path_store_split, exist_ok=True)
    df_train.to_csv(path_train_csv, index=False)
    df_test.to_csv(path_test_csv, index=False)
    vprint("saved:", path_train_csv)
    vprint("saved:", path_test_csv)
# Invoke the command-line entry point only when this file is executed as a
# script; importing the module has no side effects beyond the definitions.
if __name__ == "__main__":
    main()