Created
October 1, 2018 23:04
-
-
Save jhurliman/e3c7d6a3f0b4382b186430f27d8e1345 to your computer and use it in GitHub Desktop.
An implementation of "BEST: Bayesian Estimation Supersedes the t Test" using pymc3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from multiprocessing import cpu_count | |
import matplotlib | |
# Select the non-interactive Agg backend before pyplot is imported below.
try:
    matplotlib.use('Agg', warn=False)
except TypeError:
    # The `warn` keyword was removed from matplotlib.use in matplotlib 3.3.
    matplotlib.use('Agg')
import matplotlib.pyplot as plt # noqa: E402 | |
import numpy as np # noqa: E402 | |
import pymc3 as pm # noqa: E402 | |
import six.moves | |
# Region of practical equivalence (ROPE), in milliseconds: a difference of
# means that falls inside [-2, 2] ms is treated as practically zero.
_ROPE_HALF_WIDTH_MS = 2
ROPE = [-_ROPE_HALF_WIDTH_MS, _ROPE_HALF_WIDTH_MS]
# An implementation of "BEST: Bayesian Estimation Supersedes the t Test" using | |
# pymc3. See <http://www.indiana.edu/~kruschke/BEST/> | |
def ab_test(y1, y2): | |
y = np.concatenate([y1, y2]) | |
stdev = y.std() | |
# Prior parameters for mean | |
# mu = mu of [y1, y2], sigma = twice the sigma of [y1, y2] | |
mu_m = y.mean() | |
mu_s = stdev * 2 | |
# Prior parameters for standard deviation | |
# Uniform distribution two magnitudes above and below sigma of [y1, y2] | |
sigma_low = stdev / 100 | |
sigma_high = stdev * 100 | |
with pm.Model(): | |
a_mean = pm.Normal('a_mean', mu_m, sd=mu_s) | |
b_mean = pm.Normal('b_mean', mu_m, sd=mu_s) | |
a_std = pm.Uniform('a_std', lower=sigma_low, upper=sigma_high) | |
b_std = pm.Uniform('b_std', lower=sigma_low, upper=sigma_high) | |
nu = pm.Exponential('nu_minus_one', 1 / 29.) + 1 | |
# Use 1/(sigma^2) to work with PyMC3's parameterization of Student's | |
# t-distribution | |
pm.StudentT('a', nu=nu, mu=a_mean, lam=a_std**-2, observed=y1) | |
pm.StudentT('b', nu=nu, mu=b_mean, lam=b_std**-2, observed=y2) | |
pm.Deterministic('diff of means', b_mean - a_mean) | |
pm.Deterministic('diff of stds', b_std - a_std) | |
trace = pm.sample(6000, tune=1000, njobs=1, progressbar=False, | |
random_seed=list(six.moves.range(cpu_count()))) | |
return trace | |
def plot_ab_test(name, trace, git_commit, prev_git_commit): | |
plt.subplots(figsize=(14, 5)) | |
pm.forestplot(trace, rhat=False, varnames=['a_mean', 'b_mean', 'a_std', 'b_std'], | |
colors='#ff8c00') | |
rhat = pm.diagnostics.gelman_rubin(trace, varnames=['a_mean', 'b_mean', 'a_std', 'b_std']) | |
fig = plt.gcf() | |
ax = plt.gca() | |
# Checking lack of convergence based on: Brooks, S. P., and A. Gelman. 1997. | |
# General Methods for Monitoring Convergence of Iterative Simulations. | |
# Journal of Computational and Graphical Statistics 7: 434-455. | |
warning_str = '' | |
for var in rhat: | |
if rhat[var] >= 1.2: | |
warning_str += "WARNING: Lack of convergence for " + str(var) + ', with rhat: ' + str(rhat[var]) + "\n" | |
if warning_str: | |
plt.text(0.8, 0.2, warning_str.strip(), horizontalalignment='center', verticalalignment='center', | |
transform=ax.transAxes, bbox={'facecolor': 'red', 'alpha': 0.5, 'pad': 5}) | |
ax.set_xlabel('Time (ms)') | |
ax1 = fig.add_subplot(111) | |
pm.plot_posterior(trace, rope=ROPE, varnames=['diff of means'], | |
ax=ax1, color='#ff8c00') | |
ax1.set_xlabel('Time (ms)') | |
plt.subplots_adjust(bottom=1.1, top=2) | |
truncated_name = name[name.find('/') + 1:] | |
plt.title('A: ' + prev_git_commit + '| B: ' + git_commit) | |
plt.suptitle(truncated_name + ", Diff_Means:|B - A|", y=2.1, fontsize=16) | |
return fig |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment