Last active
October 2, 2017 22:43
-
-
Save maedoc/11696f26a9700c5062e1e1d263274305 to your computer and use it in GitHub Desktop.
Centering for SDEs in Stan?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Shared data declarations, #included inside each program's data block.
// The bounds make CmdStan reject invalid inputs at load time instead of
// producing NaNs (sqrt of a negative dt) or out-of-range indexing (T < 1).
int<lower=1> T;      // number of time points in the path
real<lower=0> sig;   // scale of the stochastic increments
real<lower=0> dt;    // time step
real m;              // drift parameter of the double-well step
real b;              // shift (centre) of the double-well step
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include prelude.stan
data {
#include basedata.stan
  vector[T] z;  // noisy observations of the latent path
}
parameters {
  real mh;  // model counterpart of the simulation's m
  real bh;  // model counterpart of the simulation's b
  // A path of length T needs only T - 1 increments. A T-th entry would be
  // constrained by its prior alone, since the likelihood never uses it.
  vector[T - 1] dWt;
  // NOTE(review): z_1 has no explicit prior, so it is improper-flat
  // (the directly-sampled model treats zh[1] the same way).
  real z_1;
}
transformed parameters {
  // Latent path rebuilt deterministically from z_1 and the increments.
  vector[T] zh;
  zh[1] = z_1;
  for (t in 1:(T - 1))
    zh[t + 1] = f(zh[t], mh, bh, dt) + dWt[t];
}
model {
  mh ~ normal(1.0, 1.0);
  bh ~ normal(1.5, 1.0);
  // Each increment has standard deviation sqrt(dt) * sig.
  dWt ~ normal(0.0, sqrt(dt) * sig);
  z ~ normal(zh, 0.1);  // fixed observation-noise sd
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Path to a CmdStan installation; override with `make CMDSTAN=...`.
CMDSTAN ?= ~/Downloads/cmdstan-2.17.0
# Number of fixed_param draws (simulated trajectories) produced by ./sim.
nsim ?= 1

# Build a CmdStan executable from <name>.stan. The programs #include both
# prelude.stan and basedata.stan, so editing either must trigger a rebuild.
# $(MAKE) -C propagates make flags and returns CmdStan's own exit status.
# The previous `cd ...; make ...; cd ...` chain returned the status of the
# final `cd`, which hid failed builds.
./%: %.stan prelude.stan basedata.stan
	$(CMDSTAN)/bin/stanc --o=$*.hpp $*.stan
	$(MAKE) -C $(CMDSTAN) $(CURDIR)/$@

# Simulate nsim trajectories using the parameters in sim.data.R.
sim.csv: sim.data.R ./sim
	./sim sample algorithm=fixed_param num_samples=$(nsim) data file=sim.data.R output file=sim.csv

# Dump the first simulated trajectory as model data, then append the
# simulation inputs (T, sig, dt, m, b) so the models can read them too.
data.R: sim.csv sim.data.R util.py
	python util.py sim2data sim.csv data.R
	cat sim.data.R >> data.R

# Fit the two model variants to the same data.
naive.csv: data.R ./naive
	./naive sample data file=data.R output file=naive.csv

centered.csv: data.R ./centered
	./centered sample data file=data.R output file=centered.csv
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include prelude.stan
data {
#include basedata.stan
  vector[T] z;  // noisy observations of the latent path
}
parameters {
  real mh;
  real bh;
  vector[T] zh;  // latent path, sampled directly as parameters
}
model {
  // Transition sd for one step of length dt.
  real step_sd = sqrt(dt) * sig;
  mh ~ normal(1.0, 1.0);
  bh ~ normal(1.5, 1.0);
  // Each state is drawn around the deterministic step from its predecessor.
  for (t in 2:T)
    zh[t] ~ normal(f(zh[t - 1], mh, bh, dt), step_sd);
  z ~ normal(zh, 0.1);  // fixed observation-noise sd
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
functions {
  // One deterministic Euler step of the shifted double-well drift
  //   dz/dt = m*(z - b) - (z - b)^3,
  // whose fixed points are z = b and, for m > 0, z = b +/- sqrt(m).
  real f(real z_, real m, real b, real dt) {
    real x = z_ - b;                  // position relative to the shift b
    real drift = m * x - x * x * x;
    return x + dt * drift + b;
  }
  // The deterministic step f plus an additive stochastic increment
  // sqrt(dt) * sig_dWt. sig_dWt is a noise draw that the caller has already
  // scaled, e.g. normal_rng(0, sig).
  real f_incr(real z, real m, real b, real sig_dWt, real dt) {
    real noise = sqrt(dt) * sig_dWt;
    return f(z, m, b, dt) + noise;
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
T <- 1000 | |
m <- 1.0 | |
b <- 1.5 | |
sig <- 0.5 | |
dt <- 0.1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include prelude.stan
data {
#include basedata.stan
}
model { }
generated quantities {
  // One simulated path of length T, started at z = b.
  vector[T] z;
  z[1] = b;
  for (t in 2:T) {
    real eps = normal_rng(0.0, sig);  // noise draw with sd sig
    z[t] = f_incr(z[t - 1], m, b, eps, dt);
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def _rdump_array(key, val):
    """Format a numpy array as a single R dump assignment line.

    A 1-D array becomes ``key <- c(...)``. Any other shape becomes
    ``key <- structure(c(...), .Dim = c(...))``. In that form the elements
    are listed in numpy's flat (C) order and the shape is written reversed.
    This matches ``parse_csv``, which reverses Stan's dimensions when it
    reshapes.
    """
    values = 'c(' + ', '.join(str(x) for x in val.flat) + ')'
    if val.shape != (val.size,):
        dims = 'c(' + ', '.join(str(n) for n in reversed(val.shape)) + ')'
        return '%s <- structure(%s, .Dim = %s)' % (key, values, dims)
    return '%s <- %s' % (key, values)
def rdump(fname, data):
    """Write ``data`` to ``fname`` in R dump format, one assignment per line.

    Each value is first converted with ``np.asarray``. This lets plain Python
    scalars and lists be dumped as well as numpy arrays; previously a plain
    ``int``/``float`` failed on ``.flat``. Values with more than one element
    are formatted by ``_rdump_array``. Single-element values are written as
    a bare scalar, ``key <- value``.
    """
    with open(fname, 'w') as fd:
        for key, val in data.items():
            arr = np.asarray(val)
            if arr.size > 1:
                line = _rdump_array(key, arr)
            else:
                line = '%s <- %s' % (key, arr.flat[0])
            fd.write(line)
            fd.write('\n')
def merge_csv_data(*csvs):
    """Combine several ``parse_csv`` results into one dict.

    Arrays that share a key are concatenated along axis 0 in argument order.
    A key that appears in only one input keeps that input's array object
    unchanged.
    """
    merged = {}
    for parsed in csvs:
        for name, arr in parsed.items():
            if name not in merged:
                merged[name] = arr
                continue
            merged[name] = np.concatenate((merged[name], arr), axis=0)
    return merged
def parse_csv(fname):
    """Parse a CmdStan output CSV file into a dict of numpy arrays.

    ``fname`` may also be a list or tuple of file names. In that case each
    file is parsed and the per-name arrays are concatenated along the draw
    axis with ``merge_csv_data``.

    Lines starting with ``#`` and blank lines are skipped. The first
    remaining line is the header. Each header field is split on ``.``, so
    ``z.3`` or ``m.2.1`` are grouped under the base name ``z`` / ``m``.
    Every later line is one draw of comma-separated numbers.

    Returns a dict mapping each base name to a float array of shape
    ``(n_draws,) + dims[::-1]``. Here ``dims`` are the indices of the last
    header column seen for that name, which are taken as the container's
    size. A header-only file yields arrays with ``n_draws == 0``.
    """
    if isinstance(fname, (list, tuple)):
        return merge_csv_data(*[parse_csv(_) for _ in fname])
    print('parsing %r' % (fname,))
    lines = []
    with open(fname, 'r') as fd:
        for line in fd:
            # Skip comments and blank lines: float('') would raise on an
            # empty line.
            if line.startswith('#') or not line.strip():
                continue
            lines.append(line.strip().split(','))
    names = [field.split('.') for field in lines[0]]
    rows = [[float(f) for f in line] for line in lines[1:]]
    # Keep the table 2-D even with zero draws, so data[:, idx] still works.
    data = np.array(rows, dtype=float) if rows else np.empty((0, len(names)))
    namemap = {}
    maxdims = {}
    for i, name in enumerate(names):
        namemap.setdefault(name[0], []).append(i)
        if len(name) > 1:
            # Columns are visited in file order, so this ends up holding the
            # indices of the last column for this name.
            maxdims[name[0]] = name[1:]
    shapes = {
        name: tuple(int(dim) for dim in reversed(dims))
        for name, dims in maxdims.items()
    }
    # Data is in linear order per Stan (e.g. a matrix is column-major);
    # the reversed shape preserves that order under a row-major reshape.
    data_ = {}
    for name, idx in namemap.items():
        data_[name] = data[:, idx].reshape((-1,) + shapes.get(name, ()))
    return data_
def sim2data(csv_fname, rdat_fname):
    """Write the first draw in ``csv_fname`` to ``rdat_fname`` as R dump data.

    Columns whose base name ends in ``__`` (e.g. CmdStan's ``lp__``) are
    left out.
    """
    parsed = parse_csv(csv_fname)
    first_draw = {}
    for name, draws in parsed.items():
        if name.endswith('__'):
            continue
        first_draw[name] = draws[0]
    rdump(rdat_fname, first_draw)
def plotsim(csv_fname):
    """Plot every simulated ``z`` trajectory in ``csv_fname`` and show it.

    ``z`` has one row per draw. It is transposed so that each draw becomes
    one black line against the time index. Uses ``matplotlib.pyplot``
    instead of the discouraged ``pylab`` namespace.
    """
    dat = parse_csv(csv_fname)
    import matplotlib.pyplot as plt
    z = dat['z']
    plt.plot(z.T, 'k')
    plt.show()
if __name__ == '__main__':
    import sys
    # Usage: python util.py <command> [args...]
    #   e.g. python util.py sim2data sim.csv data.R
    # The command is looked up by name among this module's public callables
    # rather than passing command-line text to eval().
    if len(sys.argv) < 2:
        sys.exit('usage: %s <command> [args...]' % (sys.argv[0],))
    _command = globals().get(sys.argv[1])
    if sys.argv[1].startswith('_') or not callable(_command):
        sys.exit('unknown command: %r' % (sys.argv[1],))
    _command(*sys.argv[2:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment