Reproducible Cohorting
"""
Reproducible Cohorting for Experiments
"""
import hashlib
import logging
from typing import Dict, Optional, Tuple
import numpy as np
from pandas import DataFrame
import pandas as pd
logger = logging.getLogger(__name__)


def hash_randomizer_single(x: str, salt: str) -> float:
    """
    Assign a random number based on an input string.

    Before converting to an MD5 hash, salt the string with `salt`. Once the
    hash is prepared, take the final 8 characters, convert them to an int,
    and divide by 16^8 (the size of the output space of the 8-character
    slice). What's returned is a pseudo-random float that always resolves
    to the same value given the same string and salt.

    Parameters
    ----------
    x: str
        String to Hash
    salt: str
        String to salt the Hash

    Returns
    -------
    float
        Float between 0 and 1 (>= 0, < 1)
    """
    string_to_hash = x + salt
    md5_hash = hashlib.md5(string_to_hash.encode("utf-8")).hexdigest()
    right_8_hash = md5_hash[-8:]
    numerator = int(right_8_hash, 16)
    # divide by 16 ** 8 (not 16 ** 8 - 1) so the result is strictly less than 1
    # and always falls inside a split whose exclusive upper bound is 1
    denominator = 16 ** 8
    return numerator / denominator


def hash_randomizer(x: pd.Series, salt: Optional[str] = None) -> np.ndarray:
    """
    Assign a random number to each element of an input pd.Series.

    Before hashing, each string is salted with `salt`.

    Parameters
    ----------
    x: pd.Series
        Series of strings to Hash
    salt: Optional[str]
        String to salt the Hash (defaults to an empty string)

    Returns
    -------
    np.ndarray
        Array of floats between 0 and 1
    """
    if salt is None:
        salt = ""
    vectorizer = np.vectorize(hash_randomizer_single, otypes=["float"])
    return vectorizer(x=x, salt=salt)


def validate_splits(splits: Dict[str, Tuple[float, float]]) -> bool:
    """
    Validate a dictionary of cohort "splits".

    Should be a dictionary with cohort names as the keys and tuples
    with the inclusive lower and exclusive upper bound as the values.
    The ranges for the cohorts should span 0 to 1 and be non-overlapping.

    Parameters
    ----------
    splits: Dict[str, Tuple[float, float]]
        dictionary with cohorts as keys, tuple with range of hash values as values

    Returns
    -------
    bool
        Whether the split is valid

    Examples
    --------
    >>> splits = {"Control": (0, 0.99), "Treatment": (0.99, 1)}
    >>> validate_splits(splits)
    True
    >>> splits = {"Control": (0, 0.98), "Treatment": (0.99, 1)}
    >>> validate_splits(splits)
    False
    >>> splits = {"Control": (0, 0.98), "Treatment": (0.98, 0.99)}
    >>> validate_splits(splits)
    False
    >>> splits = {"Control": (0, 0.99), "Treatment": (0.98, 1)}
    >>> validate_splits(splits)
    False
    >>> splits = {"Control": (0, 0.98), "Treatment": (0.98, 1.2)}
    >>> validate_splits(splits)
    False
    """
    sorted_vals = sorted(splits.values(), key=lambda x: x[0])
    # the lowest range must start at 0 and the highest must end at 1
    if sorted_vals[0][0] != 0:
        return False
    elif sorted_vals[-1][1] != 1:
        return False
    # each range must begin exactly where the previous one ends
    for i in range(len(sorted_vals) - 1):
        if sorted_vals[i + 1][0] != sorted_vals[i][1]:
            return False
    return True


def hash_to_cohort_single(hash_value: float, splits: Dict[str, Tuple[float, float]]) -> str:
    """
    Assign a cohort based on the random hash value between 0 and 1

    Parameters
    ----------
    hash_value: float
        Hashed cohort value. This is a float between 0 and 1
    splits: Dict[str, Tuple[float, float]]
        A Cohorting Split, such as dict(TREATMENT=(0.00, 0.85), CONTROL=(0.85, 1.00))

    Returns
    -------
    str
        The corresponding cohort name from the Splits Dictionary

    Examples
    --------
    >>> splits = {"Control": (0, 0.99), "Treatment": (0.99, 1)}
    >>> hash_to_cohort_single(0.99, splits)
    'Treatment'
    >>> hash_to_cohort_single(0.98, splits)
    'Control'
    """
    for cohort_name, cohort_splits in splits.items():
        if cohort_splits[0] <= hash_value < cohort_splits[1]:
            return cohort_name
    raise ValueError(f"Hash Value ({hash_value}) does not exist "
                     "within ranges in splits")


def hash_to_cohort(hash_value: pd.Series, splits: Dict[str, Tuple[float, float]]) -> np.ndarray:
    """
    Assign a cohort based on the random hash value between 0 and 1

    Parameters
    ----------
    hash_value: pd.Series
        pd.Series of Hashed cohort values. These are floats between 0 and 1
    splits: Dict[str, Tuple[float, float]]
        A Cohorting Split, such as dict(TREATMENT=(0.00, 0.85), CONTROL=(0.85, 1.00))

    Returns
    -------
    np.ndarray
        An array containing the corresponding cohort name from the Splits Dictionary
    """
    assert validate_splits(splits=splits) is True
    vectorizer = np.vectorize(hash_to_cohort_single, otypes=["str"])
    return vectorizer(hash_value=hash_value, splits=splits)


def set_dataframe_cohorts(cohort_df: pd.DataFrame,
                          cohort_split: Dict[str, Tuple[float, float]],
                          salt: str,
                          unique_id: str,
                          cohort_column_name: str) -> pd.DataFrame:
    """
    Assign DataFrame rows to a Cohort

    This transformation is done by hashing the unique ID and
    assigning it to a range in the cohort split dictionary

    Parameters
    ----------
    cohort_df: pd.DataFrame
        Source Cohort DF
    cohort_split: Dict[str, Tuple[float, float]]
        A Cohorting Split, such as dict(TREATMENT=(0.00, 0.85), CONTROL=(0.85, 1.00))
    salt: str
        String Value to Salt the Hash with
    unique_id: str
        Column to use as Unique Identifier in the Data Set (will be hashed with the salt)
    cohort_column_name: str
        Column name to return with Cohort Value, will overwrite the column if the name already exists

    Returns
    -------
    pd.DataFrame
        Replica DataFrame returned (original cohort_df is left untouched)
    """
    updated_cohort_df = cohort_df.copy()
    # hash the unique id column (as strings) into floats between 0 and 1
    updated_cohort_df["_hash_value"] = hash_randomizer(
        x=updated_cohort_df[unique_id].astype(str),
        salt=salt)
    # map each hash value onto its cohort range
    updated_cohort_df[cohort_column_name] = hash_to_cohort(
        updated_cohort_df["_hash_value"].copy(),
        cohort_split)
    updated_cohort_df.drop(columns=["_hash_value"], inplace=True)
    return updated_cohort_df


# Single Hashing Example
salt = "ExampleProjectName"
example_id = "12345abcde"
resulting_hash = hash_randomizer_single(x=example_id, salt=salt)
print(example_id, resulting_hash)
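# Illustrative sanity check using the functions above: the hash is deterministic
# for a given (id, salt) pair, while an arbitrary different salt (e.g. the made-up
# "AnotherProjectName") produces a different draw for the same id.
assert hash_randomizer_single(x=example_id, salt=salt) == resulting_hash
print(example_id, hash_randomizer_single(x=example_id, salt="AnotherProjectName"))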
# DataFrame Hashing Example
sequential_user_ids = list(range(1, 100))
df = DataFrame(dict(user_id=sequential_user_ids))
df_copy = df.copy()
df["hash_value"] = hash_randomizer(df.user_id.astype(str), salt=salt)
print(df.sample(5))
# Single Cohorting Example
example_cohort_split = {"a": (0, .50), "b": (.50, 1)}
df["cohort"] = hash_to_cohort(hash_value=df.hash_value,
splits=example_cohort_split)
print(df.sample(5))
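# Splits can also be checked up front with validate_splits; a minimal sketch of a
# misconfigured split (a hypothetical gap between 0.45 and 0.50) being rejected.
print(validate_splits(example_cohort_split))
print(validate_splits({"a": (0, 0.45), "b": (0.50, 1)}))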
# Full Wrapper Function Example
df_cohorted = set_dataframe_cohorts(cohort_df=df_copy,
                                    cohort_split=example_cohort_split,
                                    salt=salt,
                                    unique_id="user_id",
                                    cohort_column_name="cohort")
print(df_cohorted.sample(5))
print(df_cohorted.groupby("cohort")["user_id"].count() * 100 / len(df_cohorted))
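# Reproducibility check (a minimal sketch using the objects defined above): because
# each assignment is a pure function of the user_id and the salt, re-cohorting a
# fresh copy of the data yields the exact same cohort for every user.
df_recohorted = set_dataframe_cohorts(cohort_df=df_copy,
                                      cohort_split=example_cohort_split,
                                      salt=salt,
                                      unique_id="user_id",
                                      cohort_column_name="cohort")
print((df_cohorted["cohort"] == df_recohorted["cohort"]).all())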