from typing import Optional
import pandas as pd
def split_dataframe(
    dataframe: pd.DataFrame,
    train_frac: float,
    val_frac: Optional[float] = None,
    return_splits: bool = True,
    seed: int = 42,
):
    """Given a pandas dataframe, perform a train-val-test split and either return three different
    dataframes, or append a column identifying the split each row belongs to.

    Rows are selected by position rather than by index label, so dataframes
    with duplicated index labels are split correctly.

    TODO: extend this to enable balanced / stratified splitting

    Parameters
    ----------
    dataframe: pd.DataFrame
        Input dataframe
    train_frac: float
        Fraction of data to use for training. Must be strictly between
        0 and 1
    val_frac: Optional[float]
        Fraction of data (relative to the whole dataframe) to use for
        validation. By default, the data not used for training is split
        in half between validation and test. If `train_frac + val_frac`
        equals 1, all the remaining rows go to validation and the test
        split is empty
    return_splits: bool = True
        Whether to return the three splits separately, or to append
        a column to the existing dataframe and return the modified
        dataframe
    seed: int = 42
        Random seed for reproducibility

    Returns
    -------
    dict or pd.DataFrame
        If `return_splits` is True, a dict with the keys "train", "valid"
        and "test" mapping to the corresponding dataframes. Otherwise the
        input dataframe itself, modified in place with a "split" column
        holding one of those three labels for every row.

    Raises
    ------
    ValueError
        If `val_frac` is given and is not positive, or if `train_frac`
        leaves no room for it (`train_frac >= 1` or
        `train_frac + val_frac > 1`).
    """
    # Slack used when comparing fractions: e.g. 0.2 / (1 - 0.8) evaluates to
    # 1.0000000000000002, which must count as "all of the remaining rows".
    tolerance = 1e-9

    if val_frac is None:
        # by default use same size for val and test
        val_share = 0.5
    else:
        remainder = 1 - train_frac
        if remainder <= 0:
            raise ValueError(
                f"train_frac={train_frac} leaves no data for validation; "
                "it must be < 1 when val_frac is given"
            )
        if val_frac <= 0:
            raise ValueError(f"val_frac={val_frac} must be > 0")
        if val_frac > remainder + tolerance:
            raise ValueError(
                f"train_frac + val_frac = {train_frac + val_frac} exceeds 1"
            )
        if val_frac >= remainder - tolerance:
            # validation takes everything that is left, no test split
            val_share = 1.0
        else:
            # rescale so the fraction is relative to the held-out rows
            val_share = val_frac / remainder

    # import here to optimize CLIs / Fire usage
    from sklearn.model_selection import train_test_split

    # Split row positions instead of index labels: label-based lookups return
    # every row sharing a label, which duplicates rows across the splits when
    # the index is not unique.
    positions = list(range(len(dataframe)))
    train_pos, held_out_pos = train_test_split(
        positions, train_size=train_frac, random_state=seed
    )
    if val_share >= 1.0:
        val_pos, test_pos = held_out_pos, []
    else:
        val_pos, test_pos = train_test_split(
            held_out_pos, train_size=val_share, random_state=seed
        )

    if return_splits:
        return dict(
            train=dataframe.iloc[train_pos],
            valid=dataframe.iloc[val_pos],
            test=dataframe.iloc[test_pos],
        )

    labels = [None] * len(dataframe)
    for split_name, split_positions in (
        ("train", train_pos),
        ("valid", val_pos),
        ("test", test_pos),
    ):
        for position in split_positions:
            labels[position] = split_name
    # Reusing the dataframe's own index keeps the assignment purely
    # positional, even when index labels repeat.
    dataframe["split"] = pd.Series(labels, index=dataframe.index, dtype=object)
    return dataframe
def sample_n_each(
    dataframe: pd.DataFrame,
    column: str,
    number: int = 1,
    force: bool = False,
    seed: int = 42,
):
    """Transform a dataframe to have equal number of rows per value of `column`.

    In case a given value of `column` has less than `number` corresponding rows:

    - if `force` is True the corresponding rows are sampled with replacement
    - if `force` is False all the rows are given for that value

    Rows whose value in `column` is missing (NaN / None) are not part of any
    group and are left out of the result.

    Parameters
    ----------
    dataframe: pd.DataFrame
        Input dataframe
    column: str
        The column to be used for selection
    number: int
        Number of rows to include per unique value of `column`
    force: bool = False
        Toggle upsampling of classes with number of samples smaller
        than `number`
    seed: int
        Random seed used for sampling

    Returns
    -------
    pd.DataFrame
        The sampled rows, grouped by value of `column` in order of first
        appearance. Empty (with the same columns as the input) when there
        is no non-missing value to sample from.
    """
    subsets = []
    # Missing values never compare equal to themselves, so the equality mask
    # below would select no rows for them; upsampling that empty selection
    # would raise. Skip them explicitly.
    for value in dataframe[column].dropna().unique():
        class_rows = dataframe[dataframe[column] == value]
        if len(class_rows) >= number:
            subsets.append(class_rows.sample(number, random_state=seed))
        elif force:
            # only sample with replacement if there
            # aren't enough data points in this class
            subsets.append(
                class_rows.sample(number, random_state=seed, replace=True)
            )
        else:
            # too few rows and no upsampling: keep them all, shuffled
            subsets.append(class_rows.sample(frac=1, random_state=seed))

    if not subsets:
        # pd.concat raises on an empty list; return an empty frame instead
        return dataframe.iloc[0:0]
    return pd.concat(subsets)