Last active
April 8, 2022 15:00
-
-
Save juftin/d102392cffebd51b497363cb787c7fd3 to your computer and use it in GitHub Desktop.
Reproducible Cohorting
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Reproducible Cohorting for Experiments | |
""" | |
import hashlib | |
import logging | |
from typing import Dict, Optional, Tuple | |
import numpy as np | |
from pandas import DataFrame | |
import pandas as pd | |
logger = logging.getLogger(__name__) | |
def hash_randomizer_single(x: str, salt: str) -> float:
    """
    Assign a random number based on an input string.

    Before converting to an MD5 hash, salt the string with `salt`. Once the
    hash is prepared, take the final 8 hex characters, convert them to an
    int, and divide by 16^8 (the size of the 8-character hex output space).

    What's returned is a randomly generated float that's always resolvable to
    its original string value given a consistent salt.

    Parameters
    ----------
    x: str
        String to Hash
    salt: str
        String to salt the Hash

    Returns
    -------
    float
        Float between 0 and 1 (>= 0, < 1)
    """
    string_to_hash = x + salt
    md5_hash = hashlib.md5(string_to_hash.encode("utf-8")).hexdigest()
    right_8_hash = md5_hash[-8:]
    numerator = int(right_8_hash, 16)
    # Divide by the full output-space size (16^8), not 16^8 - 1: the
    # numerator ranges over [0, 16^8 - 1], so this guarantees the documented
    # half-open interval [0, 1). Dividing by 16^8 - 1 could yield exactly
    # 1.0, which no split range covers and which would make
    # hash_to_cohort_single raise.
    denominator = 16 ** 8
    return numerator / denominator
def hash_randomizer(x: pd.Series, salt: Optional[str] = None) -> pd.Series:
    """
    Assign a random number for every element of an input pd.Series.

    Each element is hashed via `hash_randomizer_single` after being salted.

    Parameters
    ----------
    x: pd.Series
        Series of strings to Hash
    salt: Optional[str]
        String to salt the Hash. ``None`` is treated as an empty salt.

    Returns
    -------
    pd.Series
        Array of floats between 0 and 1 (one per input element)
    """
    # The default salt was previously passed straight through as None,
    # which made hash_randomizer_single fail on `x + None` (TypeError).
    if salt is None:
        salt = ""
    vectorizer = np.vectorize(hash_randomizer_single, otypes=["float"])
    return vectorizer(x=x, salt=salt)
def validate_splits(splits: Dict[str, Tuple[float, float]]) -> bool:
    """
    Function to validate dictionary of "splits"

    Should be a dictionary with cohort values as the keys and tuples
    with the inclusive lower and exclusive upper bound. Ranges for cohorts
    should span 0 to 1 and be non-overlapping.

    Parameters
    ----------
    splits: Dict[str, Tuple[float, float]]
        dictionary with cohorts as keys, tuple with range of hash values as values

    Returns
    -------
    bool
        Whether the split is valid

    Examples
    --------
    >>> splits = {"Control": (0,0.99), "Treatment": (0.99,1)}
    >>> validate_splits(splits)
    True
    >>> splits = {"Control": (0,0.98), "Treatment": (0.99,1)}
    >>> validate_splits(splits)
    False
    >>> splits = {"Control": (0,0.98), "Treatment": (0.98,0.99)}
    >>> validate_splits(splits)
    False
    >>> splits = {"Control": (0,0.99), "Treatment": (0.98,1)}
    >>> validate_splits(splits)
    False
    >>> splits = {"Control": (0,0.98), "Treatment": (0.98,1.2)}
    >>> validate_splits(splits)
    False
    >>> validate_splits({})
    False
    """
    # An empty dict cannot cover [0, 1]; previously this raised IndexError.
    if not splits:
        return False
    sorted_vals = sorted(splits.values(), key=lambda bounds: bounds[0])
    # The lowest range must start at 0 and the highest must end at 1.
    if sorted_vals[0][0] != 0 or sorted_vals[-1][1] != 1:
        return False
    # Adjacent ranges must meet exactly: no gaps, no overlaps.
    return all(sorted_vals[i][1] == sorted_vals[i + 1][0]
               for i in range(len(sorted_vals) - 1))
def hash_to_cohort_single(hash_value: float, splits: Dict[str, Tuple[float, float]]) -> str:
    """
    Assign a cohort based on the random hash value between 0 and 1

    The first split whose half-open range ``[lower, upper)`` contains
    ``hash_value`` wins.

    Parameters
    ----------
    hash_value: float
        Hashed cohort value. This is a float between 0 and 1
    splits: Dict[str, Tuple[float, float]]
        A Cohorting Split, such as dict(TREATMENT=(0.00, 0.85), CONTROL=(0.85, 1.00))

    Returns
    -------
    str
        The corresponding cohort name from the Splits Dictionary

    Raises
    ------
    ValueError
        If no split range contains ``hash_value``.
    """
    for cohort_name, (lower_bound, upper_bound) in splits.items():
        if lower_bound <= hash_value < upper_bound:
            return cohort_name
    raise ValueError(f"Hash Value ({hash_value}) does not exist "
                     "within ranges in splits")
def hash_to_cohort(hash_value: pd.Series, splits: Dict[str, Tuple[float, float]]) -> pd.Series:
    """
    Assign a cohort based on the random hash value between 0 and 1

    Parameters
    ----------
    hash_value: pd.Series
        pd.Series of Hashed cohort values. These are floats between 0 and 1
    splits: Dict[str, Tuple[float, float]]
        A Cohorting Split, such as dict(TREATMENT=(0.00, 0.85), CONTROL=(0.85, 1.00))

    Returns
    -------
    pd.Series
        A series containing the corresponding cohort name from the Splits Dictionary

    Raises
    ------
    ValueError
        If ``splits`` does not cover [0, 1] with contiguous, non-overlapping
        ranges.
    """
    # Validate explicitly instead of via `assert`: asserts are stripped
    # when Python runs with optimizations (-O), which would silently skip
    # this safety check.
    if not validate_splits(splits=splits):
        raise ValueError(f"Invalid cohort splits: {splits}")
    vectorizer = np.vectorize(hash_to_cohort_single, otypes=["str"])
    return vectorizer(hash_value=hash_value, splits=splits)
def set_dataframe_cohorts(cohort_df: pd.DataFrame,
                          cohort_split: Dict[str, Tuple[float, float]],
                          salt: str,
                          unique_id: str,
                          cohort_column_name: str) -> pd.DataFrame:
    """
    Assign DataFrame rows to a Cohort

    This transformation is done by hashing the unique ID and
    assigning it to a plane on a cohort split dictionary

    Parameters
    ----------
    cohort_df: pd.DataFrame
        Source Cohort DF
    cohort_split: Dict[str, Tuple[float, float]]
        A Cohorting Split, such as dict(TREATMENT=(0.00, 0.85), CONTROL=(0.85, 1.00))
    salt: str
        String Value to Salt the Hash with
    unique_id: str
        Column to use as Unique Identifier (will be hashed + salted) in the Data Set
    cohort_column_name: str
        Column name to return with Cohort Value, Will overwrite columns if name already exists

    Returns
    -------
    DataFrame
        Replica DataFrame returned (original cohort_df is left untouched)
    """
    updated_cohort_df = cohort_df.copy()
    # Keep the hash values in a local variable rather than a temporary
    # "_hash_value" column: the previous approach silently clobbered (and
    # then dropped) any pre-existing "_hash_value" column in the input.
    hash_values = hash_randomizer(
        x=updated_cohort_df[unique_id].astype(str),
        salt=salt)
    updated_cohort_df[cohort_column_name] = hash_to_cohort(
        hash_values,
        cohort_split)
    return updated_cohort_df
# --- Single Hashing Example ---
# Hash one ID into a reproducible float in [0, 1).
salt = "ExampleProjectName"
example_id = "12345abcde"
resulting_hash = hash_randomizer_single(x=example_id, salt=salt)
print(example_id, resulting_hash)

# --- DataFrame Hashing Example ---
# Hash a whole column of user IDs at once.
sequential_user_ids = list(range(1, 100))
df = DataFrame(dict(user_id=sequential_user_ids))
df_copy = df.copy()
df["hash_value"] = hash_randomizer(df["user_id"].astype(str), salt=salt)
print(df.sample(5))

# --- Single Cohorting Example ---
# Map the hash values onto a 50/50 split.
example_cohort_split = {"a": (0, .50), "b": (.50, 1)}
df["cohort"] = hash_to_cohort(hash_value=df["hash_value"],
                              splits=example_cohort_split)
print(df.sample(5))

# --- Full Wrapper Function Example ---
# One call: hash the unique ID column and attach the cohort column.
df_cohorted = set_dataframe_cohorts(
    cohort_df=df_copy,
    cohort_split=example_cohort_split,
    salt=salt,
    unique_id="user_id",
    cohort_column_name="cohort",
)
print(df_cohorted.sample(5))
print(df_cohorted.groupby("cohort")["user_id"].count() * 100 / len(df_cohorted))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment