Skip to content

Instantly share code, notes, and snippets.

@fmussari
Created July 28, 2023 13:41
Show Gist options
  • Save fmussari/f6fdb783ab99a37259bec182f864a988 to your computer and use it in GitHub Desktop.
Save fmussari/f6fdb783ab99a37259bec182f864a988 to your computer and use it in GitHub Desktop.
# Inspired by:
# https://github.com/microsoft/recommenders/blob/main/recommenders/datasets/python_splitters.py
# - In the "recommenders" library, when splitting by 'item', `python_stratified_split` puts both rows of a movie that has only two ratings in the training set.
# - This version puts one of them in the training set and the other in the validation set.
import pandas as pd
import numpy as np
# Default column names for the user and item identifiers.
DEFAULT_USER_COL = 'user'
DEFAULT_ITEM_COL = 'title'


def split_dataframe_by_group(
    data: pd.DataFrame,
    ratio: float = 0.75,
    filter_by: str = "user",  # item or user
    seed: int = 42,
    col_user: str = DEFAULT_USER_COL,
    col_item: str = DEFAULT_ITEM_COL,
    sort_column=None,
) -> list:
    """Return a train/validation split stratified by 'user' or 'item'.

    Rows are grouped by the chosen column. Inside each group the rows are
    ordered (randomly, reproducibly from ``seed``, or by ``sort_column`` when
    given) and the leading ``ratio`` share of the group goes to the training
    split while the remainder goes to the validation split. A group with a
    single row is always placed in the training split.

    Args:
        data: Ratings table containing at least ``col_user`` and ``col_item``.
            It is not modified.
        ratio: Share of every group assigned to the training split (0 to 1).
        filter_by: ``"user"`` to stratify on ``col_user``, ``"item"`` to
            stratify on ``col_item``.
        seed: Seed for the random within-group ordering; ignored when
            ``sort_column`` is given.
        col_user: Name of the user column.
        col_item: Name of the item column.
        sort_column: Optional column used to order rows inside each group
            (ascending) instead of a random order.

    Returns:
        ``[train, valid]``: two lists of integer row *positions* in ``data``
        (suitable for ``data.iloc[...]``), each listed in group order and then
        within-group order.

    Raises:
        ValueError: If ``filter_by`` is not ``"user"``/``"item"``, a required
            column is missing, or ``ratio`` lies outside ``[0, 1]``.
    """
    # A few preliminary checks.
    if filter_by not in ("user", "item"):
        raise ValueError("filter_by should be either 'user' or 'item'.")
    if col_user not in data.columns:
        raise ValueError("Schema of data not valid. Missing User Col")
    if col_item not in data.columns:
        raise ValueError("Schema of data not valid. Missing Item Col")
    if not 0 <= ratio <= 1:
        raise ValueError("ratio should be a number between 0 and 1.")

    split_by_column = col_user if filter_by == "user" else col_item

    # Work on a small two-column frame keyed by row position. This avoids
    # copying every column of `data`, cannot clash with user columns named
    # "random"/"count"/"rank", and makes the result positional regardless of
    # the index carried by `data`.
    group_key = data[split_by_column].reset_index(drop=True)
    if not sort_column:
        # A private RandomState yields the same draws as np.random.seed(seed)
        # followed by np.random.rand(n), without reseeding the caller's
        # global NumPy generator.
        order_key = pd.Series(np.random.RandomState(seed).rand(len(data)))
    else:
        order_key = data[sort_column].reset_index(drop=True)

    keys = pd.DataFrame({"group": group_key, "order": order_key})
    keys = keys.sort_values(["group", "order"])
    grouped = keys.groupby("group")
    count = grouped["group"].transform("count")  # size of each row's group
    rank = grouped.cumcount() + 1  # 1-based position inside the group
    is_singleton = count == 1

    splits = []
    lower = None
    for upper in (ratio, 1.0):
        in_band = rank <= upper * count
        if lower is None:
            # When 'count' equals 1, put the row in the train set.
            in_band |= is_singleton
        else:
            in_band &= (rank > lower * count) & ~is_singleton
        splits.append(rank.index[in_band.to_numpy()].tolist())
        lower = upper
    return splits
def report_splits(
    data: pd.DataFrame,
    splits: list,
    col_user: str = DEFAULT_USER_COL,
    col_item: str = DEFAULT_ITEM_COL,
    filter_by: str = DEFAULT_USER_COL,
):
    """Print how many users/items ended up in the train and validation splits.

    Args:
        data: Ratings table the splits refer to.
        splits: Two sequences of row positions, train first and validation
            second; rows are selected with ``data.iloc``.
        col_user: Name of the user column.
        col_item: Name of the item column.
        filter_by: Column the split was stratified on. Either a column name
            present in ``data`` or one of the keywords ``"user"`` /
            ``"item"``, which resolve to ``col_user`` / ``col_item``.

    Returns:
        None. The report is written to standard output.
    """
    # An existing column of that name wins; otherwise accept the
    # "user"/"item" keywords so filter_by="item" does not raise KeyError
    # when the item column has another name.
    if filter_by in data.columns:
        group_col = filter_by
    else:
        group_col = {"user": col_user, "item": col_item}.get(filter_by, filter_by)

    # Groups with a single row can never appear in the validation split.
    # Each such group contributes exactly one row, so counting rows whose
    # group size is 1 counts those groups.
    group_sizes = data.groupby(group_col)[group_col].transform('count')
    print(f'Number of {filter_by} with one review: {int((group_sizes == 1).sum())}')

    train_df = data.iloc[splits[0]]
    valid_df = data.iloc[splits[1]]
    # dropna=False keeps a missing value counted as one distinct entry.
    print(f'Number of users in train: {train_df[col_user].nunique(dropna=False)}')
    print(f'Number of items in train: {train_df[col_item].nunique(dropna=False)}')
    print(f'Number of users in valid: {valid_df[col_user].nunique(dropna=False)}')
    print(f'Number of items in valid: {valid_df[col_item].nunique(dropna=False)}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment