Created
July 28, 2023 13:41
-
-
Save fmussari/f6fdb783ab99a37259bec182f864a988 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Inspired by:
# https://github.com/microsoft/recommenders/blob/main/recommenders/datasets/python_splitters.py
# - In the "recommenders" library, when splitting by 'item', `python_stratified_split` puts both rows of a movie that has only two ratings in the training set.
# - This version puts one of them in the training set and the other in the validation set.
import pandas as pd
import numpy as np

# Column names used as the default `col_user` / `col_item` arguments of the
# functions in this file.
DEFAULT_USER_COL = 'user'
DEFAULT_ITEM_COL = 'title'
def split_dataframe_by_group(
    data: pd.DataFrame,
    ratio: float = 0.75,
    filter_by: str = "user",  # item or user
    seed: int = 42,
    col_user: str = DEFAULT_USER_COL,
    col_item: str = DEFAULT_ITEM_COL,
    sort_column=None,
):
    """Split rows into train/validation index lists, stratified by user or item.

    Rows are grouped by the user column or the item column. Inside each group
    the rows are ordered by ``sort_column`` (or by a reproducible random key
    derived from ``seed`` when no sort column is given); the leading ``ratio``
    share of every group goes to the training split and the remainder to the
    validation split. A group that has a single row is always placed in the
    training split, and a group with two rows (at the default ratio)
    contributes one row to each split.

    Args:
        data: Frame holding at least the ``col_user`` and ``col_item`` columns.
        ratio: Share of each group assigned to train; must lie in [0, 1].
        filter_by: ``"user"`` to group by ``col_user``, ``"item"`` to group by
            ``col_item``.
        seed: Seed of the random ordering used when ``sort_column`` is falsy.
        col_user: Name of the user column.
        col_item: Name of the item column.
        sort_column: Optional column that fixes the order of rows inside each
            group (e.g. a timestamp) instead of a random order.

    Returns:
        ``[train, valid]``: two lists of index labels of ``data``, listed in
        group/sort order. With the default ``RangeIndex`` these labels are
        also the row positions.

    Raises:
        ValueError: If ``filter_by`` is not ``"user"``/``"item"``, if one of
            the user/item columns is missing, or if ``ratio`` is outside
            [0, 1].
    """
    # A few preliminary checks.
    if filter_by not in ("user", "item"):
        raise ValueError("filter_by should be either 'user' or 'item'.")
    if col_user not in data.columns:
        raise ValueError("Schema of data not valid. Missing User Col")
    if col_item not in data.columns:
        raise ValueError("Schema of data not valid. Missing Item Col")
    if not 0 <= ratio <= 1:
        raise ValueError("ratio should be between 0 and 1.")

    split_by_column = col_user if filter_by == "user" else col_item

    if not sort_column:
        # A private RandomState draws the same numbers as
        # `np.random.seed(seed); np.random.rand(n)` but leaves the caller's
        # global NumPy random state untouched.
        order_values = np.random.RandomState(seed).rand(data.shape[0])
    else:
        order_values = data[sort_column]

    # Work on a small scratch frame instead of a full copy of `data`: this
    # avoids copying unrelated columns and cannot clash with columns of
    # `data` that happen to be called "random", "count" or "rank".
    work = pd.DataFrame(index=data.index)
    work["key"] = data[split_by_column]
    work["order"] = order_values
    work = work.sort_values(["key", "order"])

    groups = work.groupby("key")
    count = groups["key"].transform("count")  # size of the row's group
    rank = groups.cumcount() + 1  # 1-based position inside the group

    # Single-row groups can only live in one split; keep them in train so the
    # user/item is seen during training.
    is_single = count == 1
    in_train = (rank <= ratio * count) | is_single
    in_valid = (rank > ratio * count) & ~is_single

    return [
        list(work.index[in_train.to_numpy()]),
        list(work.index[in_valid.to_numpy()]),
    ]
def report_splits(
    data: pd.DataFrame,
    splits: list,
    col_user: str = DEFAULT_USER_COL,
    col_item: str = DEFAULT_ITEM_COL,
    filter_by: str = DEFAULT_USER_COL,
):
    """Print how many distinct users and items each split contains.

    Also prints how many values of the ``filter_by`` column occur in exactly
    one row of ``data``.

    Args:
        data: Frame the splits were computed from.
        splits: Sequence whose first element selects the training rows and
            whose second element selects the validation rows. The rows are
            read with ``data.iloc``, i.e. by position.
            NOTE(review): index labels and positions only coincide when
            ``data`` has the default ``RangeIndex`` -- confirm for other
            indexes.
        col_user: Name of the user column.
        col_item: Name of the item column.
        filter_by: Column to count single-row groups on. When the value is
            not a column of ``data``, the aliases ``"user"`` and ``"item"``
            are resolved to ``col_user`` / ``col_item``.

    Returns:
        None. The report is written to standard output.
    """
    # Resolve the "user"/"item" aliases only as a fallback, so an existing
    # column with that exact name keeps taking precedence.
    group_col = filter_by
    if group_col not in data.columns:
        group_col = {"user": col_user, "item": col_item}.get(filter_by, filter_by)

    # Number of filter_by values with one row (these cannot appear in valid).
    group_sizes = data.groupby(group_col)[group_col].transform('count')
    single_count = int((group_sizes == 1).sum())
    print(f'Number of {filter_by} with one review: {single_count}')

    train_df = data.iloc[splits[0]]
    valid_df = data.iloc[splits[1]]

    print(f'Number of users in train: {len(train_df[col_user].unique())}')
    print(f'Number of items in train: {len(train_df[col_item].unique())}')
    print(f'Number of users in valid: {len(valid_df[col_user].unique())}')
    print(f'Number of items in valid: {len(valid_df[col_item].unique())}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment