Created January 25, 2017 01:45
-
-
Save jenkspt/b06e9258bfa1547c761dfb7c06c06ac1 to your computer and use it in GitHub Desktop.
Multivariate Testing with the Beta Distribution
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.stats as scs | |
import matplotlib.pyplot as plt | |
def plot(ax, x, y, label):
    """Draw the curve y(x) on ``ax`` and shade the region between it and 0.

    The fill uses the colour of the line that ``ax.plot`` just created, so
    the curve and its shaded area match.
    """
    first_line, *_ = ax.plot(x, y, label=label, lw=2)
    shade_colour = first_line.get_c()
    ax.fill_between(x, 0, y, alpha=0.2, color=shade_colour)
def beta_posterior(clicks, impressions, prior_alpha=1, prior_beta=1):
    """Return the Beta posterior distribution of a click-through rate.

    With a Beta(prior_alpha, prior_beta) prior and ``clicks`` successes out of
    ``impressions`` Bernoulli trials, the conjugate posterior is
    Beta(prior_alpha + clicks, prior_beta + impressions - clicks).
    The default Beta(1, 1) prior is the uniform distribution on [0, 1], so
    ``beta_posterior(0, 0)`` is exactly that uniform prior.

    Parameters
    ----------
    clicks : int
        Number of impressions that resulted in a click.
    impressions : int
        Total number of impressions shown.
    prior_alpha, prior_beta : float, optional
        Shape parameters of the Beta prior (default 1 and 1, i.e. uniform).

    Returns
    -------
    A frozen ``scipy.stats.beta`` distribution.

    Raises
    ------
    ValueError
        If ``clicks`` is negative or larger than ``impressions``.
    """
    if not 0 <= clicks <= impressions:
        raise ValueError(
            f"need 0 <= clicks <= impressions, got clicks={clicks}, "
            f"impressions={impressions}"
        )
    # The second shape parameter counts failures (impressions WITHOUT a
    # click), not total impressions.
    return scs.beta(prior_alpha + clicks, prior_beta + impressions - clicks)


if __name__ == '__main__':
    # Simulated data (clicks out of impressions) for variations a and b.
    a_clicks, a_impressions = 10, 180
    b_clicks, b_impressions = 14, 140
    # Variation c has no data yet.
    c_clicks, c_impressions = 0, 0

    a_dist = beta_posterior(a_clicks, a_impressions)  # Posterior after 180 impressions
    b_dist = beta_posterior(b_clicks, b_impressions)  # Posterior after 140 impressions
    # With no data the posterior is just the uniform Beta(1, 1) 'prior'.
    c_dist = beta_posterior(c_clicks, c_impressions)

    # Draw independent samples from each posterior; comparing the draws
    # element-wise gives Monte Carlo estimates of probabilities about the
    # true (unknown) click rates.
    sample_size = 10000
    a_samp = a_dist.rvs(size=sample_size)
    b_samp = b_dist.rvs(size=sample_size)

    # Here you can answer important questions more directly than you can
    # with frequentist methods.
    prob = np.mean(b_samp > a_samp)
    print("Probability that variation b is better than variation a:", prob)

    # Maybe there is a cost to making a change, so b must win by a margin.
    # The margin is absolute: 0.05 means 5 percentage points of click rate.
    margin = 0.05
    prob = np.mean(b_samp > a_samp + margin)
    print("Probability that variation b beats variation a by more than "
          "5 percentage points:", prob)

    # This is also easy to extend to multiple variations: sample each
    # posterior and compare the draws in the same way.

    # Plot the posterior densities over the beta distribution's support [0, 1].
    # linspace includes both endpoints exactly (arange with a float step may not).
    x = np.linspace(0.0, 1.0, 1001)
    fig, ax = plt.subplots(1, 1)
    for dist, label in ((a_dist, 'Variation a'),
                        (b_dist, 'Variation b'),
                        (c_dist, 'Uniform Distribution')):
        plot(ax, x, dist.pdf(x), label)
    ax.legend()
    plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.