Skip to content

Instantly share code, notes, and snippets.

@EoinTravers
Last active April 30, 2020 13:05
Show Gist options
  • Save EoinTravers/0e856173b67a86718e65a7c9c7b8358a to your computer and use it in GitHub Desktop.
Save EoinTravers/0e856173b67a86718e65a7c9c7b8358a to your computer and use it in GitHub Desktop.
Testing DAG in brms/rstan
library(brms)
library(rstan)
## Simulate data
# Fix the RNG seed so that the simulated data set can be regenerated exactly.
set.seed(2020)

# Path coefficients of the DAG  x -> y,  x -> z,  y -> z.
bxy <- 2     # effect of x on y
bxz <- 3     # direct effect of x on z
byz <- 4     # effect of y on z
e <- 0.1     # residual SD used in both structural equations
n <- 1000    # number of simulated observations

# x is exogenous; y depends on x; z depends on both x and y.
# With e this small, y is very nearly bxy * x, so x and y are strongly
# collinear as predictors of z.
x <- rnorm(n)
y <- bxy * x + rnorm(n, mean = 0, sd = e)
z <- bxz * x + byz * y + rnorm(n, mean = 0, sd = e)
df <- data.frame(x = x, y = y, z = z)
## Using brms
# Two-equation path model: y ~ x and z ~ x + y, both gaussian.
#
# Residual correlation between y and z is switched off explicitly. Because y
# is itself a predictor of z (and nothing affects y without also affecting z),
# a free rescor(y, z) cannot be separated from the z_x / z_y slopes: the model
# would have more free parameters than the data can pin down. Turning it off
# also makes this model match the Stan program, which uses two independent
# normal likelihoods.
my_formula <- bf(y ~ x) + bf(z ~ x + y) + set_rescor(FALSE)
my_model <- brm(my_formula, data = df, chains = 2, cores = 2)
summary(my_model)
# NOTE: the summary pasted below comes from the earlier run in which the
# residual correlation WAS estimated. It shows the symptom this change
# addresses: rescor(y,z) spans roughly -0.97 to 0.96, and z_x / z_y have
# Rhat 1.05 with ESS below 20.
# Family: MV(gaussian, gaussian)
# Links: mu = identity; sigma = identity
# mu = identity; sigma = identity
# Formula: y ~ x
# z ~ x + y
# Data: df (Number of observations: 1000)
# Samples: 2 chains, each with iter = 2000; warmup = 1000; thin = 1;
# total post-warmup samples = 2000
#
# Population-Level Effects:
# Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
# y_Intercept -0.00 0.00 -0.01 0.01 1.02 119 304
# z_Intercept -0.00 0.01 -0.02 0.01 1.01 77 84
# y_x 2.00 0.00 2.00 2.01 1.01 189 204
# z_x 2.81 3.70 -5.72 10.15 1.05 18 14
# z_y 4.09 1.85 0.43 8.36 1.05 18 14
#
# Family Specific Parameters:
# Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
# sigma_y 0.10 0.00 0.09 0.10 1.01 90 201
# sigma_z 0.19 0.10 0.10 0.46 1.03 39 36
#
# Residual Correlations:
# Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
# rescor(y,z) -0.03 0.69 -0.97 0.96 1.05 18 14
#
# Samples were drawn using sampling(NUTS). For each parameter, Eff.Sample
# is a crude measure of effective sample size, and Rhat is the potential
# scale reduction factor on split chains (at convergence, Rhat = 1).
## Using rstan
# Data passed to dag.stan: the columns of df (x, y, z) followed by the
# sample size N.
stan_data <- c(as.list(df), list(N = n))

# The first call is made with iter = 0; its result is then handed to the
# second call through `fit =`, so the model built for the first call is
# reused for the actual 2000-iteration run.
stan_model0 <- stan(
  "dag.stan",
  data = stan_data,
  cores = 1,
  chains = 1,
  iter = 0
)
stan_model <- stan(
  fit = stan_model0,
  data = stan_data,
  cores = 1,
  chains = 1,
  iter = 2000
)
stan_model
# Inference for Stan model: dag.
# 1 chains, each with iter=2000; warmup=1000; thin=1;
# post-warmup draws per chain=1000, total post-warmup draws=1000.
#
# mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
# ay 0.00 0.0 0.00 -0.01 0.00 0.00 0.00 0.01 838 1.00
# az 0.00 0.0 0.00 -0.01 0.00 0.00 0.00 0.00 1180 1.00
# bxy 2.00 0.0 0.00 2.00 2.00 2.00 2.01 2.01 996 1.00
# bxz 3.05 0.0 0.07 2.93 3.00 3.05 3.10 3.19 360 1.01
# byz 3.97 0.0 0.04 3.90 3.95 3.97 4.00 4.04 363 1.01
# sy 0.10 0.0 0.00 0.09 0.10 0.10 0.10 0.10 525 1.00
# sz 0.10 0.0 0.00 0.10 0.10 0.10 0.11 0.11 619 1.00
# lp__ 3577.55 0.1 1.90 3572.83 3576.57 3577.92 3578.86 3580.17 397 1.01
#
# Samples were drawn using NUTS(diag_e) at Thu Apr 30 13:55:24 2020.
# For each parameter, n_eff is a crude measure of effective sample size,
# and Rhat is the potential scale reduction factor on split chains (at
# convergence, Rhat=1).
data {
  int<lower=1> N;  // number of observations
  vector[N] x;     // predictor of both y and z
  vector[N] y;     // outcome of x, predictor of z
  vector[N] z;     // outcome of x and y
}
parameters {
  // Intercepts of the y and z equations
  real ay;
  real az;
  // Slopes
  real bxy;  // effect of x on y
  real bxz;  // effect of x on z
  real byz;  // effect of y on z
  // Residual standard deviations (constrained non-negative)
  real<lower=0> sy;
  real<lower=0> sz;
}
model {
  // Linear predictors of the two equations.
  vector[N] mu_y = ay + bxy * x;
  vector[N] mu_z = az + bxz * x + byz * y;

  // Two separate normal likelihoods; no prior statements are given here.
  y ~ normal(mu_y, sy);
  z ~ normal(mu_z, sz);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment