import argparse
import os
import numpy as np
import pandas as pd
def int_or_float(x):
    """Parse *x* as a non-negative int or as a float in ``[0.0, 1.0]``.

    Intended for use as an ``argparse`` ``type=`` callable: an integer is
    tried first, and only if that fails is the value parsed as a float.

    Args:
        x: Text to parse (argparse passes the raw command-line string).

    Returns:
        An ``int`` >= 0 if *x* parses as an integer, otherwise a ``float``
        between 0.0 and 1.0 inclusive.

    Raises:
        ValueError: If *x* is neither a valid int nor a valid float.
        argparse.ArgumentTypeError: If the parsed value is out of range.
    """
    try:
        val = int(x)
    except ValueError:
        # Not an integer literal; fall back to float. A ValueError raised
        # here propagates unchanged, which argparse reports as a usage error.
        val = float(x)
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``. The chained comparison also rejects NaN.
        if not 0.0 <= val <= 1.0:
            raise argparse.ArgumentTypeError(
                "float value must be in [0.0, 1.0], got {!r}".format(x)
            )
    else:
        if val < 0:
            raise argparse.ArgumentTypeError(
                "int value must be >= 0, got {!r}".format(x)
            )
    return val
def main():
    """Split a dataset CSV into ``train.csv`` and ``test.csv``.

    Arguments are read from the command line via ``argparse``. The two output
    files are written to ``<dst_dir>/<name>/``, where ``<name>`` is the source
    file's base name up to its first ``.``. If both output files already
    exist, the function returns without touching them.

    Unless ``--no_shuffle`` is given, rows are shuffled with a
    ``numpy.random.RandomState`` seeded by ``--seed`` before splitting.
    ``--train_size`` is either an absolute number of rows (int) or a fraction
    of the dataset (float); the first rows go to the training set and the
    remainder to the test set.

    Raises:
        AttributeError: If the parsed ``train_size`` is neither an int nor a
            float.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("src_csv", help="path to dataset CSV")
    parser.add_argument("dst_dir", help="destination directory of dataset split")
    parser.add_argument(
        "--train_size",
        type=int_or_float,
        default=0.8,
        help="training set size as int or fraction of total dataset size",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument(
        "--no_shuffle",
        action="store_true",
        help="do not shuffle the dataset before splitting",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="verbose")
    opts = parser.parse_args()

    # No-op printer when not verbose, so call sites stay unconditional.
    vprint = print if opts.verbose else lambda *a, **kw: None

    # Everything before the first dot of the file name ("a.b.csv" -> "a").
    name = os.path.basename(opts.src_csv).split(".")[0]
    path_store_split = os.path.join(opts.dst_dir, name)
    path_train_csv = os.path.join(path_store_split, "train.csv")
    path_test_csv = os.path.join(path_store_split, "test.csv")

    # An existing split is reused as-is, whatever options produced it.
    if os.path.exists(path_train_csv) and os.path.exists(path_test_csv):
        vprint("Using existing train/test split.")
        return

    rng = np.random.RandomState(opts.seed)
    df_all = pd.read_csv(opts.src_csv)
    if not opts.no_shuffle:
        df_all = df_all.sample(frac=1.0, random_state=rng).reset_index(drop=True)

    if isinstance(opts.train_size, int):
        idx_split = opts.train_size
    elif isinstance(opts.train_size, float):
        idx_split = round(len(df_all) * opts.train_size)
    else:
        raise AttributeError(
            "train_size must be int or float, got {}".format(
                type(opts.train_size).__name__
            )
        )

    # Explicit positional slicing. A split index of 0 yields an empty
    # training frame that still carries the column headers.
    df_train = df_all.iloc[:idx_split]
    df_test = df_all.iloc[idx_split:]
    vprint("train/test sizes: {:d}/{:d}".format(len(df_train), len(df_test)))

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(path_store_split, exist_ok=True)
    df_train.to_csv(path_train_csv, index=False)
    df_test.to_csv(path_test_csv, index=False)
    vprint("saved:", path_train_csv)
    vprint("saved:", path_test_csv)
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()