Skip to content

Instantly share code, notes, and snippets.

@djsegal
Last active January 2, 2020 07:40
Show Gist options
  • Save djsegal/d8c9e02fa75490e2686c4a866438cad7 to your computer and use it in GitHub Desktop.
Save djsegal/d8c9e02fa75490e2686c4a866438cad7 to your computer and use it in GitHub Desktop.
import pandas as pd
from sklearn.model_selection import train_test_split
def custom_train_test_split(
    cur_data, random_state=42, cur_target="price", cur_boolean="is_brooklyn"
):
    """Split ``cur_data`` into shuffled train/test features and targets.

    If ``cur_data`` has a column named ``cur_boolean``, that column is used
    as a row mask: the masked rows and the remaining rows are each split by
    ``_custom_train_test_split`` and the resulting pieces are concatenated.
    Without that column the whole frame is split in a single call.

    Args:
        cur_data: DataFrame holding an ``"id"`` column and the target column.
        random_state: Seed forwarded to the split helper and to the shuffles.
        cur_target: Name of the target column.
        cur_boolean: Name of the optional mask column to split within.

    Returns:
        ``(X_train, X_test, y_train, y_test)``. The feature frames have the
        ``"id"`` and target columns removed; the targets are the
        ``cur_target`` column of the matching rows.
    """
    if cur_boolean not in cur_data.columns:
        train_rows, test_rows = _custom_train_test_split(
            cur_data, random_state, cur_target
        )
    else:
        mask = cur_data[cur_boolean]
        # Masked rows first, then the complement, each split independently.
        pieces = [
            _custom_train_test_split(subset, random_state, cur_target)
            for subset in (cur_data[mask], cur_data[~mask])
        ]
        train_rows = pd.concat([train_part for train_part, _ in pieces])
        test_rows = pd.concat([test_part for _, test_part in pieces])

    # frac=1 keeps every row and only permutes the order.
    train_rows = train_rows.sample(frac=1, random_state=random_state)
    test_rows = test_rows.sample(frac=1, random_state=random_state)

    non_features = ["id", cur_target]
    return (
        train_rows.drop(columns=non_features),
        test_rows.drop(columns=non_features),
        train_rows[cur_target],
        test_rows[cur_target],
    )
def _custom_train_test_split(cur_data, random_state, cur_target):
bin_count = 10
binned_ids = pd.qcut(
cur_data.groupby("id")["id"].count(),
bin_count, labels=False
)
def _train_test_lambda(cur_value):
sub_bin_ids = binned_ids[binned_ids == cur_value]
sampled_ids = sub_bin_ids.sample(
frac=0.1, random_state=random_state
)
return sampled_ids.index
cur_map = map(_train_test_lambda, range(bin_count))
test_ids = [
cur_item for cur_list in cur_map for cur_item in cur_list
]
cur_test_1 = cur_data[cur_data.id.isin(test_ids)]
cur_train = cur_data[~cur_data.id.isin(test_ids)]
cur_stratify = pd.qcut(cur_train[cur_target], bin_count, labels=False)
cur_count = int(round( 0.2 * len(cur_data) - len(cur_test_1) ))
cur_train, cur_test_2 = train_test_split(
cur_train,
test_size=cur_count,
stratify=cur_stratify,
random_state=random_state
)
cur_test = pd.concat([cur_test_1, cur_test_2])
return cur_train, cur_test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment