Created
April 19, 2017 08:14
-
-
Save maedoc/6cee50c3133774c0aeb1e7e7a51e5b1d to your computer and use it in GitHub Desktop.
Optimizing log probability by hand
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print('importing') | |
import numpy as np | |
import sympy as sp | |
import scipy.optimize | |
import random | |
print('functions') | |
def gen_trace(n_time, sig, dt, a):
    """Simulate a noisy trajectory of the cubic system ``x' = x - x**3/3 + a``.

    The state starts from a draw of ``random.gauss(0, 3)`` and is advanced
    ``n_time`` times with a forward-Euler step of size ``dt`` plus additive
    Gaussian noise.

    Parameters
    ----------
    n_time : int
        Number of steps to simulate.
    sig : float
        Standard deviation of the noise added at every step.
    dt : float
        Euler step size.
    a : float
        Constant offset in the drift term.

    Returns
    -------
    numpy.ndarray
        The ``n_time`` states produced after each step (the random initial
        state itself is not included).
    """
    x = random.gauss(0, 3)
    xs = []
    for _ in range(n_time):
        dx = x - x**3/3 + a
        # The noise must have standard deviation `sig` so that it matches the
        # transition term N(xn, mu, sig) built in make_cse.  Scaling a
        # gauss(0, sig) draw by `sig` again would give a deviation of sig**2.
        x = x + dt * dx + sig * random.gauss(0, 1)
        xs.append(x)
    return np.array(xs)
def evalcse(cse, **args):
    """Numerically evaluate a ``(replacements, reduced_exprs)`` pair.

    This is the structure returned by ``sympy.cse``.  Every expression is
    converted with ``str()`` and evaluated with ``eval`` in a namespace
    seeded from ``args``; each replacement is stored under its symbol name so
    that later replacements and the reduced expressions can refer to it.

    Parameters
    ----------
    cse : sequence
        ``cse[0]`` is an iterable of ``(symbol, expression)`` pairs, evaluated
        in order; ``cse[1]`` is an iterable of the final expressions.
    **args
        Values (scalars or arrays) bound to the free names of the expressions.

    Returns
    -------
    list
        One evaluated value per entry of ``cse[1]``, in the same order.

    Raises
    ------
    Exception
        Whatever ``eval`` raises; the failing expression is printed first and
        the original exception is re-raised unchanged.

    Notes
    -----
    ``eval`` executes arbitrary code: only pass expressions from a trusted
    source (e.g. generated by SymPy in this same program).
    """
    ns = dict(args)
    for k, v in cse[0]:
        try:
            ns[str(k)] = eval(str(v), ns)
        except Exception as exc:
            print(exc, k, v)
            raise
    result = []
    for v in cse[1]:
        try:
            result.append(eval(str(v), ns))
        except Exception as exc:
            # Only the expression is reported here: the loop variable of the
            # replacement loop above is stale (or unbound when there are no
            # replacements) and would mask the real error.
            print(exc, v)
            raise
    return result
def N(x, mu, sd):
    """Return the quadratic part of a Gaussian log-density.

    Computes ``-(mu - x)**2 / (2 * sd**2)``.  The normalisation terms of the
    full log-density (``-log(sd)`` and the ``2*pi`` constant) are not
    included.  Works with anything supporting arithmetic operators (numbers,
    arrays, symbolic expressions).
    """
    deviation = mu - x
    return -deviation**2 / (2 * sd ** 2)
print('symbolics') | |
def make_cse():
    """Build the symbolic one-step log-probability and its derivatives.

    The log-probability is the sum of a transition term ``N(xn, mu, sig)``,
    where ``mu = x + dt * (x - x**3/3 + a)``, and an observation term
    ``N(o, x, 0.1)``.  The expression list is the log-probability followed by
    its partial derivatives with respect to ``dt``, ``sig``, ``a``, ``x`` and
    ``xn`` (in that order).

    Returns
    -------
    tuple
        The result of ``sp.cse`` applied to that list; it is also printed.
    """
    x, xn, sig, dt, o, a = sp.symbols('x xn sig dt o a')

    drift = x - x**3/3 + a
    predicted = x + dt * drift
    logp = N(xn, predicted, sig) + N(o, x, 0.1)

    exprs = [logp]
    for var in (dt, sig, a, x, xn):
        exprs.append(logp.diff(var))

    cse = sp.cse(exprs)
    print(cse)
    return cse
# 0 1 2 3 4 5 6 | |
# x - - - - - - | |
# xn - - - - - - | |
# lp lp lp lp lp | |
def f(x, cse, dt, sig, o):
    """Objective function: negative summed log-probability.

    Parameters
    ----------
    x : array
        Parameter vector; ``x[0]`` is ``a`` and ``x[1:]`` is the state trace.
    cse : tuple
        Expression bundle evaluated by ``evalcse`` (six results expected,
        the first being the log-probability).
    dt, sig : float
        Passed through to the expressions.
    o : array
        Observations; all but the last sample are used.

    Returns
    -------
    The negated sum of the per-transition log-probability values.
    """
    a, states = x[0], x[1:]
    values = evalcse(
        cse,
        a=a,
        dt=dt,
        sig=sig,
        x=states[:-1],   # current states
        xn=states[1:],   # next states
        o=o[:-1],
    )
    logp, _d_dt, _d_sig, _d_a, _d_x, _d_xn = values
    return -logp.sum()
def zeros_like(x):
    """Return a zero-filled NumPy array shaped like ``x``.

    ``numpy.ndarray`` inputs keep their shape and dtype.  Other array-like
    inputs (lists, tuples, scalars) are handled by ``np.zeros_like`` as well,
    instead of silently yielding ``None`` and failing later at the first
    indexing operation.
    """
    return np.zeros_like(x)
def fp(x, cse, dt, sig, o):
    """Gradient of the objective ``f`` with respect to the vector ``x``.

    ``x[0]`` is ``a`` and ``x[1:]`` is the state trace, as in ``f``.  The
    returned array has the same layout as ``x``:

    * entry 0 is the derivative with respect to ``a``, summed over all
      transitions;
    * every state except the last receives the derivative of the transition
      in which it is the current state;
    * every state except the first additionally receives the derivative of
      the transition in which it is the next state.

    The result is negated because ``f`` is the negative log-probability.
    """
    grad = zeros_like(x)
    a, states = x[0], x[1:]
    _lp, _d_dt, _d_sig, d_a, d_x, d_xn = evalcse(
        cse,
        a=a,
        dt=dt,
        sig=sig,
        x=states[:-1],
        xn=states[1:],
        o=o[:-1],
    )
    grad[0] = d_a.sum()
    grad[1:-1] = d_x      # contribution as "current" state
    grad[2:] += d_xn      # contribution as "next" state
    return -grad
# ---------------------------------------------------------------------------
# Demo: simulate an observed trace, then estimate `a` and the state trace by
# minimising the negative log-probability, first with BFGS and then with a
# naive fixed-step gradient descent.
# ---------------------------------------------------------------------------
ft = np.float32  # working precision for every array below
dt = ft(0.1)
sig = ft(0.2)
a = ft(-.05)
x = np.random.randn(200).astype(ft)  # random initial guess for the states
o = gen_trace(x.size, dt=dt, sig=sig, a=a).astype(ft)  # observed trace

from pylab import plot, show, cla, title
cla()
plot(o, 'k')

args = make_cse(), dt, sig, o
x0 = np.r_[1.0, x].astype(ft)  # parameter vector: [a, states...]

# BFGS using the analytic gradient.
oh = scipy.optimize.fmin_bfgs(f, x0, fp, args)
ah = oh[0]
xh = oh[1:]
plot(xh - 0.1, 'r--')  # shifted down by 0.1 so it does not hide the data
title('ah = %f' % (ah, ))

# Naive gradient descent with a fixed step of 0.01.  Work on a copy so the
# in-place update below does not overwrite the shared initial guess x0.
xi = x0.copy()
dx = fp(xi, *args)
dxn = (dx**2).sum()
nge = 0
dxn1 = dxn + 10
# Run at least 5 iterations, then continue while the squared gradient norm
# still drops by more than 1e-4 per iteration.
while nge < 5 or (dxn1 - dxn) > 1e-4:
    dxn1 = dxn.copy()
    dx = fp(xi, *args)
    dxn = (dx**2).sum()
    xi -= 0.01 * dx
    nge += 1
print('naive descent', nge, 'evals')
plot(xi[1:] + 0.1, 'b', alpha=0.2)  # shifted up by 0.1 for visibility
title('ah = %f, ah2 = %f' % (ah, xi[0]))

# Show the figure only once everything has been drawn: with a blocking
# backend an earlier show() would display the data alone and the estimate
# curves would never appear on the same axes.
show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment