Created
October 27, 2017 17:44
-
-
Save sammosummo/a169c871c5950255b7d6189973b38ac1 to your computer and use it in GitHub Desktop.
Bayesian hierarchical logistic regression
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Simple fully contained script to fit a Bayesian hierarchical logistic | |
regression model using PyMC3. | |
""" | |
import theano | |
import matplotlib | |
matplotlib.use('Agg') # seems to be necessary on Unix | |
import numpy as np | |
import pandas as pd | |
import pymc3 as pm | |
import theano.tensor as tt | |
from patsy import dmatrix | |
import matplotlib.pyplot as plt | |
def compress(data, outcome, formula, submodel):
    """Compress the data frame from "Bernoulli" format (one row per trial) to
    "binomial" format (one row per condition). This can dramatically speed up
    sampling.

    Args:
        data (pd.DataFrame): Data in Bernoulli format.
        outcome (str): Name of the outcome variable (i.e., dependent variable)
        formula (str): Patsy formula specifying the independent variables.
        submodel (str): Name of variable used to index submodels.

    Returns:
        pd.DataFrame: Data in binomial format, with one row per unique
        combination of predictor values and columns 'n<outcome>s' (number
        of positive trials) and 'ntrials' (total trials).
    """
    # NOTE: membership is a substring test against the formula text, so a
    # column is treated as a predictor whenever its name appears in `formula`.
    predictors = [t for t in data.columns if t in formula or t == submodel]
    entries = []
    for idx, df in data.groupby(predictors):
        # With a single grouping column, older pandas yields a scalar key
        # rather than a 1-tuple; normalize so the zip below always works.
        if not isinstance(idx, tuple):
            idx = (idx,)
        entry = dict(zip(predictors, idx))
        entry['n%ss' % outcome] = df[outcome].sum()
        entry['ntrials'] = len(df)
        entries.append(entry)
    return pd.DataFrame(entries)
def model(data, outcome, formula, submodel):
    """Constructs model and places it in context, ready for sampling.

    Must be called inside a ``with pm.Model():`` block — each ``pm.*``
    constructor below registers its variable with the model on the
    context stack.

    Args:
        data (pd.DataFrame): Data in binomial format.
        outcome (str): Name of the outcome variable (i.e., dependent variable)
        formula (str): Patsy formula specifying the independent variables.
        submodel (str): Name of variable used to index submodels.

    Returns:
        None: Model placed in context.
    """
    design_matrix = dmatrix(formula, data)
    submodel_names = data[submodel].unique()
    # Integer submodel index per row, aligned with submodel_names.
    sub_ix = data[submodel].replace(
        {r: i for i, r in enumerate(submodel_names)}).values
    betas = []
    for n in design_matrix.design_info.column_names:
        # Sanitize the patsy column name so it is safe inside variable names.
        n = n.replace('"', '').replace("'", '').replace(' ', '').replace(',', '')
        # Group-level (hyper) priors on this column's coefficient.
        μ = pm.Cauchy(name='μ_' + n, alpha=0, beta=5)
        σ = pm.HalfCauchy(name='σ_' + n, beta=5)
        # Non-centered parameterization: δ ~ N(0, 1), β = μ + δ·σ, which
        # samples better than drawing β ~ N(μ, σ) directly.
        δ = [pm.Normal(
            name='δ_(%s=%s)_(condition=%s)' % (submodel, r, n), mu=0., sd=1.
        ) for r in submodel_names]
        β = [pm.Deterministic(
            'β_(%s=%s)_(condition=%s)' % (submodel, r, n), μ + d * σ
        ) for d, r in zip(δ, submodel_names)]
        betas.append(β)
    # B has shape (n submodels, n design columns); B[sub_ix] picks each
    # data row's own submodel coefficients.
    B = tt.stack(betas, axis=0).T
    p = pm.invlogit(tt.sum(np.asarray(design_matrix) * B[sub_ix], axis=1))
    pm.Binomial(
        name='n%ss' % outcome,
        p=p,
        n=data.ntrials.values,
        observed=data['n%ss' % outcome].values
    )
def run():
    """Run the model: load data, compress it, sample, and save trace plots
    plus a CSV summary into the model's output directory.
    """
    with pm.Model():
        model_name = ''  # give your model a name
        data = pd.read_csv('')  # path to Bernoulli-formatted CSV
        outcome = ''  # DV column name
        formula = ''  # Patsy-style formula
        submodel = ''  # submodel index column
        data = compress(data, outcome, formula, submodel)
        print(data.head())
        model(data, outcome, formula, submodel)
        backend = pm.backends.Text(model_name)
        trace = pm.sample(10000, tune=2000, trace=backend)
        # Read the sampled parameter names from the text backend's header;
        # close the handle promptly and strip the trailing newline so the
        # last name isn't left with '\n' attached.
        with open('%s/chain-0.csv' % model_name) as f:
            params = f.readline().strip().split(',')
        # Only plot the group-level means (names containing 'μ').
        params = [p for p in params if 'μ' in p]
        pm.traceplot(trace, params)
        plt.savefig('%s/traceplot.png' % model_name)
        plt.clf()
        pm.plot_posterior(trace, params)
        plt.savefig('%s/posteriors.png' % model_name)
        plt.clf()
        pm.df_summary(trace).to_csv('%s/summary.csv' % model_name)
if __name__ == '__main__':
    run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment