Created
June 23, 2017 07:45
-
-
Save maedoc/18ab9979f14d4e532aec60444f39fbdc to your computer and use it in GitHub Desktop.
tvb algorithms, distilled
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy import * | |
import numpy as np | |
from numpy.random import randn, rand | |
def spmat(A):
    """Build a matrix-vector product that only touches non-zeros of ``A``.

    The non-zero entries of ``A`` and their row/column indices are extracted
    once; the returned closure then evaluates the equivalent of ``A.dot(b)``.

    Parameters
    ----------
    A : 2-D ndarray
        Dense matrix whose non-zero entries define the operator.

    Returns
    -------
    mvmult : callable
        ``mvmult(b)`` takes a 1-D array ``b`` with ``A.shape[1]`` elements
        and returns a 1-D array with ``A.shape[0]`` elements.
    """
    n_rows = A.shape[0]
    mask = A != 0                      # non-zero mask
    vals = A[mask]                     # non-zero elements
    rows, cols = argwhere(mask).T      # non-zero row & col indices
    # argwhere lists entries row by row, so each change of the row index
    # marks the start of a new segment for reduceat.
    seg_starts, = argwhere(diff(r_[-1, rows])).T
    nz_rows = unique(rows)             # rows with non-zeros

    def mvmult(b):
        # Size the output by the number of rows of A and promote the dtype,
        # so integer inputs are not truncated and non-square A works.
        out = zeros(n_rows, dtype=result_type(vals.dtype, b.dtype))
        # we only touch non-zeros, so only assign to rows with non-zeros
        out[nz_rows] = add.reduceat(vals * b.take(cols), seg_starts)
        return out

    return mvmult
def test_spmat():
    """Check ``spmat`` against the dense and the scipy CSR products.

    A random 300x300 matrix with roughly 80% of its entries zeroed is
    multiplied with a random vector using all three implementations.
    """
    from scipy.sparse import csr_matrix
    n = 300
    A = randn(n, n)
    A[rand(n, n) < 0.8] = 0
    b = randn(n)
    sa = spmat(A)
    csra = csr_matrix(A)
    result = sa(b)
    assert allclose(result, A.dot(b))
    # The CSR matrix was previously built but never compared against.
    assert allclose(result, csra.dot(b))
def conn(W, D, dt, pre, post, ncv, cut=0, icf=lambda h: h):
    """Build a delayed-coupling propagator for a weighted network.

    Entry ``W[r, c]`` weights the contribution of the delayed state of node
    ``c`` to the input of node ``r``; only entries with ``W > cut`` are kept.

    Parameters
    ----------
    W : 2-D ndarray
        Connection weights.
    D : ndarray, same shape as ``W``
        Delays; ``D / dt`` is truncated to an integer number of steps.
    dt : float
        Time step used to convert delays to steps.
    pre : callable
        ``pre(xi, xj)`` is called with the current state of the receiving
        node and the delayed state of the sending node for every kept
        connection (both shaped ``(n_connections, ncv)``).
    post : callable
        ``post(gx)`` is applied to the weighted sum of ``pre`` over the
        incoming connections of each node that has at least one.
    ncv : int
        Size of the last axis of the state and of the history buffer.
    cut : float, optional
        Threshold below or at which weights are ignored.
    icf : callable, optional
        Receives the zero-filled history buffer of shape ``(H, n, ncv)`` and
        returns the buffer to use (initial conditions).

    Returns
    -------
    prop : callable
        ``prop(i, xi)`` stores ``xi`` (shape ``(n, ncv)``) in the history at
        slot ``i % H`` and returns an array like ``xi`` holding the coupling
        input; rows without incoming connections stay zero.
    """
    n = W.shape[0]
    mask = W > cut                       # kept-connection mask
    w = W[mask]                          # kept weights
    di = (D[mask] / dt).astype('i')      # delays in time steps
    r, c = argwhere(mask).T              # receiving (row) & sending (col) nodes
    # Entries come row by row, so a change in r starts a new reduceat segment.
    lri, = argwhere(diff(r_[-1, r])).T
    nzr = unique(r)                      # rows with at least one connection
    # max() of an empty array raises; one history slot is enough when no
    # connection survives the cut.
    H = di.max() + 1 if di.size else 1
    hist = icf(zeros((H, n, ncv)))

    def prop(i, xi):
        hist[i % H] = xi
        out = zeros_like(xi)
        if not w.size:
            return out
        xj = hist[(i - di) % H, c]
        # pre gets the receiving node's current state (index r), not the
        # sending node's (index c), so difference couplings are correct.
        gx = add.reduceat((w * pre(xi[r], xj).T).T, lri)
        out[nzr] = post(gx)
        return out

    return prop
def test_conn():
    """Smoke-test ``conn`` with a difference coupling on a random network.

    The matrix and state vector are built locally (the same way as in
    ``test_spmat``) because they are not available as module globals.
    """
    def pre(xi, xj):
        return xj - xi

    def post(gx):
        return 0.1 * gx - 0.2

    n = 300
    A = randn(n, n)
    A[rand(n, n) < 0.8] = 0
    b = randn(n)
    # The same matrix serves as weights and delays; only its positive
    # entries pass the default cut, so all used delays are positive.
    prop = conn(A, A, 0.1, pre, post, 1)
    out = prop(23, b.reshape((-1, 1)))
    assert out.shape == (n, 1)
def v2r(rmap, vtx):
    """Sum per-vertex values into per-region totals.

    ``rmap[k]`` is the region index of vertex ``k``; the result has
    ``rmap.max() + 1`` rows and the trailing shape of ``vtx``.
    """
    n_regions = rmap.max() + 1
    trailing_shape = vtx.shape[1:]
    region_sums = zeros((n_regions, ) + trailing_shape)
    # add.at accumulates repeated indices instead of overwriting them.
    add.at(region_sums, rmap, vtx)
    return region_sums
def r2v(rmap, reg):
    """Expand per-region values to vertices by indexing ``reg`` with ``rmap``."""
    per_vertex = reg[rmap]
    return per_vertex
def test_rmap():
    """Smoke-test the vertex -> region -> vertex round trip through ``conn``.

    Everything the test needs is built locally: the original body relied on
    ``n``, ``prop`` and ``randint``, none of which are defined at module
    level.
    """
    n, nv = 300, 5000
    A = randn(n, n)
    A[rand(n, n) < 0.8] = 0
    prop = conn(A, A, 0.1,
                lambda xi, xj: xj - xi,
                lambda gx: 0.1 * gx - 0.2,
                1)
    vtx = r_[:nv][:, newaxis]
    # Shuffle a map that hits every region, so v2r always yields n rows and
    # the propagator receives a state of the expected shape.
    rmap = np.random.permutation(r_[:nv] % n)
    out = r2v(rmap, prop(23, v2r(rmap, vtx)))
    assert out.shape == (nv, 1)
def exact_prop_test():
    """Compare ``conn`` on a 4-node chain against a known exact trajectory.

    Node ``i`` receives input from node ``i + 1`` with unit weight and a
    delay of ``D[i, i + 1]`` steps; the history starts at 1 everywhere.
    """
    n = 4
    # build simple connectivity: ones on the first super-diagonal
    W = eye(n, k=1)
    D = r_[:n * n].reshape((n, n))

    # custom initial conditions function
    def icf(hist):
        hist[:] = 1.0
        return hist

    # make propagator
    prop = conn(W, D, 1, lambda i, j: j, lambda g: g, 1, icf=icf)
    # run simulation
    x = ones((n, 1))
    xs = []
    for i in range(10):
        # equiv. to Sum(Model) w/ Identity(Integrator) in TVB test
        x += prop(i, x)
        # x is updated in place, so store a copy of the current state
        xs.append(x.ravel().copy())
    # correct output
    xs_ = array([
        [2., 2., 2., 1.],
        [3., 3., 3., 1.],
        [5., 4., 4., 1.],
        [8., 5., 5., 1.],
        [12., 6., 6., 1.],
        [17., 7., 7., 1.],
        [23., 8., 8., 1.],
        [30., 10., 9., 1.],
        [38., 13., 10., 1.],
        [48., 17., 11., 1.]])
    # The comparison used the unimported name `numpy` and its result was
    # thrown away; assert it so a mismatch actually fails.
    assert allclose(array(xs), xs_)
def test_em_white():
    """Plot 50 Euler-Maruyama sample paths of a scalar SDE with white noise.

    NOTE(review): ``plot`` is not defined in this file; presumably it comes
    from a pylab-style namespace -- confirm before running.
    """
    # step
    def em_white(f, dt, D):
        sigma = sqrt(2 * D * dt)

        def step(x):
            noise = sigma * randn()
            return x + dt * f(x) + noise

        return step

    # de
    def drift(x):
        return (x - x**3/3) + 0.67

    advance = em_white(drift, 0.01, 1e-3)
    # time step
    n_paths, n_steps = 50, 10000
    for _path in range(n_paths):
        state = -2.0
        trajectory = []
        for _step in range(n_steps):
            state = advance(state)
            trajectory.append(state)
        plot(trajectory, 'k', alpha=0.5)
def test_em_color():
    """Plot 50 Euler-Maruyama sample paths of a scalar SDE with colored noise.

    The noise state is carried over from one path to the next; only the
    system state is reset to -2.0 at the start of each path.

    NOTE(review): ``plot`` is not defined in this file; presumably it comes
    from a pylab-style namespace -- confirm before running.
    """
    # step
    def em_color(f, dt, D, l):
        noise0 = sqrt(D * l) * randn()
        E = exp(-l * dt)

        def step(x, e):
            x = x + dt * (f(x) + e)
            h = sqrt(D * l * (1 - E**2)) * randn()
            return x, e * E + h

        return step, noise0

    # de
    def drift(x):
        return (x - x**3/3) + 0.67

    advance, noise = em_color(drift, 0.01, 1e-3, 10.0)
    # time step
    n_paths, n_steps = 50, 10000
    for _path in range(n_paths):
        state = -2.0
        trajectory = []
        for _step in range(n_steps):
            state, noise = advance(state, noise)
            trajectory.append(state)
        plot(trajectory, 'k', alpha=0.5)
def test_general_em_color():
    """Plot 50 colored-noise sample paths produced by a generator scheme.

    The local generator takes a state-dependent noise amplitude ``g(x)`` and
    yields ``(state, noise)`` after every step.

    NOTE(review): ``plot`` is not defined in this file; presumably it comes
    from a pylab-style namespace -- confirm before running.
    """
    def em_color(f, g, dt, lam, x):
        noise = sqrt(g(x) * lam) * randn()
        E = exp(-lam * dt)
        while True:
            x += dt * (f(x) + noise)
            h = sqrt(g(x) * lam * (1 - E**2)) * randn()
            noise = noise * E + h
            yield x, noise

    def f(x):
        return (x - x**3/3) + 0.67

    def g(x):
        return 1e-3

    n_paths, n_steps = 50, 10000
    for _path in range(n_paths):
        stream = em_color(f, g, 0.01, 10.0, -2.0)
        # range comes first in zip, so the generator is advanced exactly
        # n_steps times and never asked for an extra sample.
        xs = [x for _step, (x, _noise) in zip(range(n_steps), stream)]
        plot(xs, 'k', alpha=0.5)
def em_color(f, g, Δt, λ, x):
    """Generate states of an SDE driven by exponentially correlated noise.

    Parameters
    ----------
    f : callable
        ``f(i, x)`` returns the drift at step ``i``.
    g : callable
        ``g(i, x)`` returns the noise amplitude term at step ``i``.
    Δt : float
        Time step.
    λ : float
        Rate appearing in the decay factor ``exp(-λ * Δt)`` of the noise.
    x : ndarray
        Initial state. It is updated IN PLACE and the same array object is
        yielded every time, so callers may modify it between steps and must
        copy it if they want to keep a snapshot.

    Yields
    ------
    (x, noise) : tuple
        The state array and the current noise array; the first pair is the
        initial state, before any step has been taken.
    """
    step = 0
    shape = x.shape
    noise = sqrt(g(step, x) * λ) * randn(*shape)
    decay = exp(-λ * Δt)
    while True:
        yield x, noise
        step += 1
        drift = f(step, x)
        x += Δt * (drift + noise)
        innovation = sqrt(g(step, x) * λ * (1 - decay**2)) * randn(*shape)
        noise = noise * decay + innovation
def test_nd_em_color():
    """Run ``em_color`` on a 3-variable system and plot trace and histogram.

    NOTE(review): ``figure``, ``subplot``, ``plot`` and ``hist`` are not
    defined in this file; presumably they come from a pylab-style
    namespace -- confirm before running.
    """
    def f(i, x):
        return x - x**3/3 - sum(x)

    def g(i, x):
        return exp(x) * 0.5

    state = zeros(3)
    n_steps = 10000
    trace = zeros((n_steps, state.size))
    stream = em_color(f, g, 0.01, 0.5, state)
    for t, (x, _) in zip(range(n_steps), stream):
        # x is the generator's live state array; assignment copies its values
        trace[t] = x
    figure(figsize=(10, 5))
    subplot(121)
    plot(trace)
    subplot(122)
    hist(trace.flat[:], 100, color='k')
def load_wd():
    """Download the ``connectivity_76.zip`` archive and load two matrices.

    Returns
    -------
    (W, D) : tuple of ndarray
        The contents of ``weights.txt`` and ``tract_lengths.txt`` from the
        archive, parsed with ``loadtxt``.

    Requires network access; errors from the download or from reading the
    archive propagate to the caller.
    """
    import urllib.request
    import zipfile

    url = ('https://github.com/the-virtual-brain/tvb-data/'
           'raw/master/tvb_data/connectivity/connectivity_76.zip')
    try:
        path, _headers = urllib.request.urlretrieve(url)
        # Close the archive and its members deterministically instead of
        # leaving the handles to the garbage collector.
        with zipfile.ZipFile(path) as zf:
            with zf.open('weights.txt') as fh:
                W = loadtxt(fh)
            with zf.open('tract_lengths.txt') as fh:
                D = loadtxt(fh)
    finally:
        # Remove the temporary file urlretrieve may have left behind.
        urllib.request.urlcleanup()
    return W, D
def demo_delays_no_plot():
    """Time the delayed-network simulation for three speeds, without plots.

    Loads the weights and tract lengths with ``load_wd``, simulates the
    network for each speed and prints the wall-clock time of every run.
    """
    from time import time

    W, D = load_wd()

    def sim(dt=0.05, tf=150.0, k=0.0, speed=1.0, ω=1.0):
        n_nodes = W.shape[0]

        def pre(i, j):
            return j - 1.0

        def post(gx):
            return k * gx

        # tract lengths divided by speed give the delays passed to conn
        prop = conn(W, D / speed, dt, pre, post, 1)

        def f(i, X):  # monostable
            x, y = X.T
            c, = prop(i, x.reshape((-1, 1))).T
            dx = ω * (x - x**3/3 + y) * 3.0
            dy = ω * (1.01 - x + c) / 3.0
            return array([dx, dy]).T

        def g(i, X):  # additive linear noise
            return sqrt(1e-9)

        state = zeros((n_nodes, 2))
        n_steps = int(tf/dt)
        trace = zeros((n_steps, ) + state.shape)
        steps = r_[:n_steps]
        for t, (x, _) in zip(steps, em_color(f, g, dt, 1e-1, state)):
            # x is the integrator's live state, so writing to it here sets
            # the state the next step starts from.
            if t == 0:
                x[:] = -1.0
            if t == 1:
                x[:] = rand(n_nodes, 2)/5 + r_[1.0, -0.6]
            trace[t] = x
        return steps, trace

    dt = 0.05
    for speed in [1.0, 2.0, 10.0]:
        tic = time()
        sim(dt, 150.0, 1e-3, speed)
        elapsed = time() - tic
        print('%.3fs elapsed' % (elapsed, ))
def demo_delays():
    """Simulate the delayed network at three speeds and plot the results.

    For every speed the top row shows the first state variable of all nodes
    over time and the bottom row a histogram of the delays of the non-zero
    connections. The total simulation time is printed at the end.

    NOTE(review): ``figure``, ``subplot``, ``plot``, ``hist``, ``show`` and
    the other plotting names are not defined in this file; presumably they
    come from a pylab-style namespace -- confirm before running.
    """
    from time import time

    W, D = load_wd()

    def sim(dt=0.05, tf=150.0, k=0.0, speed=1.0, ω=1.0):
        n = W.shape[0]
        pre = lambda i, j: j - 1.0
        post = lambda gx: k * gx
        prop = conn(W, D / speed, dt, pre, post, 1)

        def f(i, X):  # monostable
            x, y = X.T
            c, = prop(i, x.reshape((-1, 1))).T
            dx = ω * (x - x**3/3 + y) * 3.0
            dy = ω * (1.01 - x + c) / 3.0
            return array([dx, dy]).T

        def g(i, X):  # additive linear noise
            return sqrt(1e-9)

        X = zeros((n, 2))
        Xs = zeros((int(tf/dt), ) + X.shape)
        T = r_[:Xs.shape[0]]
        for t, (x, _) in zip(T, em_color(f, g, dt, 1e-1, X)):
            # x is the integrator's live state, so writing to it here sets
            # the state the next step starts from.
            if t == 0:
                x[:] = -1.0
            if t == 1:
                x[:] = rand(n, 2)/5 + r_[1.0, -0.6]
            Xs[t] = x
        return T, Xs

    dt = 0.05
    figure(figsize=(12, 6))
    elapsed = 0.0
    for i, speed in enumerate([1.0, 2.0, 10.0]):
        tic = time()
        t, x = sim(dt, 150.0, 1e-3, speed)
        elapsed += time() - tic
        subplot(2, 3, i + 1)
        plot(t[::5] * dt, x[::5, :, 0] + 0 * r_[:W.shape[0]], 'k', alpha=0.3)
        grid(True, axis='x')
        xlim([0, t[-1] * dt])
        title('Speed = %g mm/ms' % (speed, ))
        xlabel('time (ms)')
        ylabel('X(t)')
        subplot(2, 3, i + 4)
        hist((D[W != 0] / speed).flat[:], 100, color='k')
        xlim([0, t[-1] * dt])
        grid(True)
        xlabel('delay (ms)')
        ylabel('# delay')
    tight_layout()
    print('%.3fs elapsed' % (elapsed, ))
    show()

    # Extra analysis that is defined but never called: a longer run with its
    # spectrum, followed by a timing sweep over step sizes.
    def _():
        ω = 10.0 / 1e3 * 2 * pi
        dt = (1 / ω) / 10.0
        t, x = sim(dt, 5e3, -1e-2, 10.0, ω=0.02 * pi)
        t = t * dt * 1e-3
        # estimate spectrum
        Fx = abs(fft.fft(x[t > 1, :, 0], axis=0).mean(axis=1))
        fs = fft.fftfreq(Fx.size, dt*1e-3)
        figure(figsize=(10, 5))
        plot(t[::5], x[::5, :, 0] + 1 * r_[:W.shape[0]], 'k', alpha=0.3)
        xlabel('time (s)')
        ylabel('$x_i(t)$')
        ylim([-20, 80])
        # `axisbg` was removed from matplotlib; `facecolor` replaces it.
        ax = axes([0.6, 0.25, 0.3, 0.3], facecolor='w')
        ax.loglog(fs[fs >= 0], Fx[fs >= 0], 'k')
        ax.set_yticks([])
        ax.set_xlabel('Freq (Hz)')
        grid(1)
        tight_layout()
        # Timing sweep; the clock is the `time` imported above, so the time
        # array `t` is no longer shadowed by a timer alias.
        ds = r_[5.0, 10.0, 20.0, 50.0, 100.0]
        et = []
        for ds_i in ds:
            dt = (1 / ω) / ds_i
            tic = time()
            sim(dt, 1e3, -1e-2, 10.0, ω=0.02 * pi)
            et.append(time() - tic)
        loglog((1 / ω) / ds, 1/array(et), 'ko')
        # raw string: '\D' is not a valid escape in a normal string literal
        xlabel(r'$\Delta t$')
        title('Speed up over realtime')
        grid(True)
# Run the timing demo only when executed as a script, so importing this
# module does not trigger a network download and a simulation.
if __name__ == "__main__":
    demo_delays_no_plot()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Autograd can probably be applied directly to these functions, as long as the gradient is not taken with respect to the delays.