Created
March 22, 2021 21:56
-
-
Save sammosummo/703f60103c16001dedd38977abc9c25a to your computer and use it in GitHub Desktop.
Pure Python implementation of the full DDM log-likelihood function with exponential contaminants
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Attempts to implement DDM likelihoods in Python. | |
""" | |
from math import pi, sqrt, log, ceil, floor, exp, sin, fabs, inf | |
def simpson_1d(x, v, sv, a, z, t, err, lb_z, ub_z, n_sz, lb_t, ub_t, n_st):
    """Average ``pdf_sv`` over z or over t with a fixed-step Simpson rule.

    When ``n_st == 0`` the starting point is swept across ``[lb_z, ub_z]``
    while the non-decision time is pinned at ``t``; otherwise the
    non-decision time is swept across ``[lb_t, ub_t]`` while the starting
    point is pinned at ``z``. The grid has ``max(n_st, n_sz)`` sub-intervals.

    The Simpson sum is divided by the width of the swept interval, so the
    value returned is the mean density over that interval.

    Note: Simpson's rule assumes an even number of sub-intervals; this
    function does not check that.
    """
    steps = max(n_st, n_sz)
    if n_st == 0:
        # Sweep z, hold t fixed.
        step_z = (ub_z - lb_z) / steps
        step_t = 0
        lb_t = ub_t = t
    else:
        # Sweep t, hold z fixed.
        step_z = 0
        step_t = (ub_t - lb_t) / steps
        lb_z = ub_z = z

    # f(a) enters with weight 1.
    total = pdf_sv(x - lb_t, v, sv, a, lb_z, err)
    for idx in range(1, steps + 1):
        z_here = lb_z + step_z * idx
        t_here = lb_t + step_t * idx
        value = pdf_sv(x - t_here, v, sv, a, z_here, err)
        # Odd nodes get weight 4, even nodes weight 2.
        total += (4 if idx % 2 else 2) * value
    # The final node f(b) was added with weight 2; bring it back to weight 1.
    total -= value
    # One of the two widths is zero, so this divides by the swept width.
    total /= (ub_t - lb_t) + (ub_z - lb_z)
    return (step_t + step_z) * total / 3
def simpson_2d(x, v, sv, a, z, t, err, lb_z, ub_z, n_sz, lb_t, ub_t, n_st):
    """Average ``pdf_sv`` over both z and t with nested fixed-step Simpson rules.

    The outer rule walks ``n_st`` sub-intervals of ``[lb_t, ub_t]``. At each
    t node the inner integral over ``[lb_z, ub_z]`` is delegated to
    ``simpson_1d`` (called with ``n_st=0`` so that it integrates over z).
    The outer sum is divided by ``ub_t - lb_t``; the inner call already
    divides by the z width.
    """
    step_t = (ub_t - lb_t) / n_st

    # Inner z-average at the lower t bound (weight 1).
    total = simpson_1d(x, v, sv, a, z, lb_t, err, lb_z, ub_z, n_sz, 0, 0, 0)
    for idx in range(1, n_st + 1):
        t_here = lb_t + step_t * idx
        inner = simpson_1d(x, v, sv, a, z, t_here, err, lb_z, ub_z, n_sz, 0, 0, 0)
        # Odd nodes get weight 4, even nodes weight 2.
        total += (4 if idx % 2 else 2) * inner
    # The final node was added with weight 2; bring it back to weight 1.
    total -= inner
    total /= ub_t - lb_t
    return step_t * total / 3
def adapt_simpson_aux(
    x,
    v,
    sv,
    a,
    z,
    t,
    pdf_err,
    lb_z,
    ub_z,
    lb_t,
    ub_t,
    ZT,
    simps_err,
    S,
    f_beg,
    f_end,
    f_mid,
    bottom,
):
    """Recursive step of the 1-D adaptive Simpson integration of ``pdf_sv``.

    ``S`` is the Simpson estimate on the current interval built from the
    already-evaluated integrand values ``f_beg``, ``f_mid`` and ``f_end``.
    The interval is bisected, the integrand is evaluated at the two new
    quarter points (scaled by ``1 / ZT``), and the two half-interval
    estimates are compared with ``S``.

    Recursion stops when ``bottom`` (remaining depth) reaches zero or the
    two estimates agree to within ``15 * simps_err``; otherwise each half is
    refined with half the error budget.

    A zero-width t interval means the integration variable is z, otherwise
    it is t.
    """
    if (ub_t - lb_t) == 0:
        # Integrating over the starting point; t stays fixed.
        span = ub_z - lb_z
        mid_z = (ub_z + lb_z) / 2.0
        left_z = (lb_z + mid_z) / 2.0
        right_z = (mid_z + ub_z) / 2.0
        mid_t = left_t = right_t = t
    else:
        # Integrating over the non-decision time; z stays fixed.
        span = ub_t - lb_t
        mid_t = (ub_t + lb_t) / 2.0
        left_t = (lb_t + mid_t) / 2.0
        right_t = (mid_t + ub_t) / 2.0
        mid_z = left_z = right_z = z

    # Integrand at the midpoints of the left and right halves.
    f_left = pdf_sv(x - left_t, v, sv, a, left_z, pdf_err) / ZT
    f_right = pdf_sv(x - right_t, v, sv, a, right_z, pdf_err) / ZT

    left_est = (span / 12) * (f_beg + 4 * f_left + f_mid)
    right_est = (span / 12) * (f_mid + 4 * f_right + f_end)
    refined = left_est + right_est

    if bottom <= 0 or fabs(refined - S) <= 15 * simps_err:
        # Accept, applying the usual adaptive-Simpson correction term.
        return refined + (refined - S) / 15

    # Arguments that are identical for both recursive calls.
    fixed = (x, v, sv, a, z, t, pdf_err)
    left_half = adapt_simpson_aux(
        *fixed,
        lb_z,
        mid_z,
        lb_t,
        mid_t,
        ZT,
        simps_err / 2,
        left_est,
        f_beg,
        f_mid,
        f_left,
        bottom - 1,
    )
    right_half = adapt_simpson_aux(
        *fixed,
        mid_z,
        ub_z,
        mid_t,
        ub_t,
        ZT,
        simps_err / 2,
        right_est,
        f_mid,
        f_end,
        f_right,
        bottom - 1,
    )
    return left_half + right_half
def adapt_simpson_1d(
    x, v, sv, a, z, t, pdf_err, lb_z, ub_z, lb_t, ub_t, simps_err, maxRecursionDepth
):
    """Average ``pdf_sv`` over z or over t with adaptive Simpson integration.

    If ``ub_t - lb_t`` is zero the starting point is integrated over
    ``[lb_z, ub_z]`` with the non-decision time pinned at ``t``; otherwise
    the non-decision time is integrated over ``[lb_t, ub_t]`` with the
    starting point pinned at ``z``.

    The integrand is divided by the interval width, so the result is the
    mean density over the interval. ``maxRecursionDepth`` bounds the number
    of bisection levels used by ``adapt_simpson_aux``.
    """
    if (ub_t - lb_t) == 0:
        # Integrate over z at fixed t.
        lb_t = ub_t = t
        width = ub_z - lb_z
    else:
        # Integrate over t at fixed z.
        width = ub_t - lb_t
        lb_z = ub_z = z

    def scaled_density(t_shift, start):
        # Integrand pre-divided by the interval width.
        return pdf_sv(x - t_shift, v, sv, a, start, pdf_err) / width

    f_beg = scaled_density(lb_t, lb_z)
    f_end = scaled_density(ub_t, ub_z)
    f_mid = scaled_density((lb_t + ub_t) / 2.0, (lb_z + ub_z) / 2.0)

    # Coarse three-point Simpson estimate that seeds the recursion.
    coarse = (width / 6) * (f_beg + 4 * f_mid + f_end)
    return adapt_simpson_aux(
        x,
        v,
        sv,
        a,
        z,
        t,
        pdf_err,
        lb_z,
        ub_z,
        lb_t,
        ub_t,
        width,
        simps_err,
        coarse,
        f_beg,
        f_end,
        f_mid,
        maxRecursionDepth,
    )
def adapt_simpson_aux_2d(
    x,
    v,
    sv,
    a,
    z,
    t,
    pdf_err,
    err_1d,
    lb_z,
    ub_z,
    lb_t,
    ub_t,
    st,
    err_2d,
    S,
    f_beg,
    f_end,
    f_mid,
    maxRecursionDepth_sz,
    bottom,
):
    """Recursive step of the 2-D adaptive Simpson integration over t and z.

    The outer variable is the non-decision time on ``[lb_t, ub_t]``. Each
    outer integrand value is itself an adaptive 1-D integral over
    ``[lb_z, ub_z]`` (via ``adapt_simpson_1d`` with tolerance ``err_1d`` and
    depth ``maxRecursionDepth_sz``), divided by ``st``.

    ``S`` is the current outer Simpson estimate built from ``f_beg``,
    ``f_mid`` and ``f_end``. Recursion stops when ``bottom`` reaches zero or
    the refined estimate agrees with ``S`` to within ``15 * err_2d``.
    """
    mid_t = (ub_t + lb_t) / 2.0
    left_t = (lb_t + mid_t) / 2.0
    right_t = (mid_t + ub_t) / 2.0
    span = ub_t - lb_t

    def z_average(at_t):
        # Inner integral over z at a fixed non-decision time, scaled by 1/st.
        # Passing lb_t = ub_t = 0 makes adapt_simpson_1d integrate over z.
        inner = adapt_simpson_1d(
            x, v, sv, a, z, at_t, pdf_err, lb_z, ub_z, 0, 0, err_1d, maxRecursionDepth_sz
        )
        return inner / st

    f_left = z_average(left_t)
    f_right = z_average(right_t)

    left_est = (span / 12) * (f_beg + 4 * f_left + f_mid)
    right_est = (span / 12) * (f_mid + 4 * f_right + f_end)
    refined = left_est + right_est

    if bottom <= 0 or fabs(refined - S) <= 15 * err_2d:
        # Accept, applying the usual adaptive-Simpson correction term.
        return refined + (refined - S) / 15

    # Arguments that are identical for both recursive calls.
    head = (x, v, sv, a, z, t, pdf_err, err_1d, lb_z, ub_z)
    left_half = adapt_simpson_aux_2d(
        *head,
        lb_t,
        mid_t,
        st,
        err_2d / 2,
        left_est,
        f_beg,
        f_mid,
        f_left,
        maxRecursionDepth_sz,
        bottom - 1,
    )
    right_half = adapt_simpson_aux_2d(
        *head,
        mid_t,
        ub_t,
        st,
        err_2d / 2,
        right_est,
        f_mid,
        f_end,
        f_right,
        maxRecursionDepth_sz,
        bottom - 1,
    )
    return left_half + right_half
def adapt_simpson_2d(
    x,
    v,
    sv,
    a,
    z,
    t,
    pdf_err,
    lb_z,
    ub_z,
    lb_t,
    ub_t,
    simps_err,
    maxRecursionDepth_sz,
    maxRecursionDepth_st,
):
    """Average ``pdf_sv`` over both z and t with nested adaptive Simpson rules.

    The outer integration runs over the non-decision time on
    ``[lb_t, ub_t]`` with at most ``maxRecursionDepth_st`` bisection levels.
    Each outer integrand value is an adaptive 1-D integral over
    ``[lb_z, ub_z]`` (at most ``maxRecursionDepth_sz`` levels) divided by the
    t width. The same tolerance ``simps_err`` is used for both levels.
    """
    width_t = ub_t - lb_t

    def z_average(at_t):
        # Inner integral over z at a fixed non-decision time, scaled by the
        # t width. Passing lb_t = ub_t = 0 selects integration over z.
        inner = adapt_simpson_1d(
            x,
            v,
            sv,
            a,
            z,
            at_t,
            pdf_err,
            lb_z,
            ub_z,
            0,
            0,
            simps_err,
            maxRecursionDepth_sz,
        )
        return inner / width_t

    f_beg = z_average(lb_t)
    f_end = z_average(ub_t)
    f_mid = z_average((lb_t + ub_t) / 2)

    # Coarse three-point Simpson estimate that seeds the outer recursion.
    coarse = (width_t / 6) * (f_beg + 4 * f_mid + f_end)
    return adapt_simpson_aux_2d(
        x,
        v,
        sv,
        a,
        z,
        t,
        pdf_err,
        simps_err,
        lb_z,
        ub_z,
        lb_t,
        ub_t,
        width_t,
        simps_err,
        coarse,
        f_beg,
        f_end,
        f_mid,
        maxRecursionDepth_sz,
        maxRecursionDepth_st,
    )
def ftt_01w(tt, w, err=1e-4):
    """Return f(tt | 0, 1, w) following Navarro and Fuss (2009).

    ``tt`` is the normalized time, ``w`` the relative starting point and
    ``err`` the error tolerance used to decide how many series terms are
    needed. Two series are available (a small-time and a large-time
    expansion); whichever needs fewer terms for the requested tolerance is
    evaluated.
    """
    # Terms needed by the large-time expansion.
    large_arg = pi * tt * err
    if large_arg < 1:
        # Tolerance is small enough for the bound to apply; the second
        # argument enforces the boundary condition.
        kl = max(
            sqrt(-2 * log(large_arg) / (pi ** 2 * tt)),
            1.0 / (pi * sqrt(tt)),
        )
    else:
        # Tolerance too loose: fall back to the boundary condition.
        kl = 1.0 / (pi * sqrt(tt))

    # Terms needed by the small-time expansion.
    small_arg = 2 * sqrt(2 * pi * tt) * err
    if small_arg < 1:
        ks = max(2 + sqrt(-2 * tt * log(small_arg)), sqrt(tt) + 1)
    else:
        # Minimal kappa when the tolerance is too loose.
        ks = 2

    density = 0
    if ks < kl:
        # Small-time expansion needs fewer terms.
        n_terms = ceil(ks)
        first = -floor((n_terms - 1) / 2.0)
        last = ceil((n_terms - 1) / 2.0)
        for k in range(first, last + 1):
            shifted = w + 2 * k
            density += shifted * exp(-(shifted ** 2) / 2 / tt)
        # Constant factor of the small-time series.
        return density / sqrt(2 * pi * tt ** 3)

    # Large-time expansion needs fewer (or equally many) terms.
    n_terms = ceil(kl)
    for k in range(1, n_terms + 1):
        density += k * exp(-(k ** 2) * (pi ** 2) * tt / 2) * sin(k * pi * w)
    # Constant factor of the large-time series.
    return density * pi
def pdf(x, v, a, w, err=1e-4):
    """Return f(x | v, a, w) following Navarro and Fuss (2009).

    ``x`` is the (non-negative) decision time, ``v`` the drift rate, ``a``
    the boundary separation and ``w`` the relative starting point. ``err``
    is forwarded to ``ftt_01w``. Non-positive times have zero density.
    """
    if x <= 0:
        return 0

    a_squared = a ** 2
    # Density of the standardized process at normalized time x / a^2.
    base = ftt_01w(x / a_squared, w, err)
    # Rescale to the requested drift rate and boundary separation.
    return base * exp(-v * a * w - (v ** 2) * x / 2.0) / a_squared
def pdf_sv(x, v, sv, a, z, err=1e-4):
    """Return f(x | v, a, z, sv): the first-passage density with drift variability.

    Uses the method of Navarro and Fuss (2009) for the base density and the
    analytic integration over the drift rate of Tuerlinckx et al. (2001).

    ``x`` is the decision time, ``v`` the mean drift rate, ``sv`` the
    between-trial standard deviation of the drift rate, ``a`` the boundary
    separation and ``z`` the relative starting point. ``err`` is forwarded
    to ``ftt_01w``. Non-positive times have zero density, and ``sv == 0``
    falls back to ``pdf``.
    """
    # Time must be positive.
    if x <= 0:
        return 0
    # No drift variability: nothing to integrate.
    if sv == 0:
        return pdf(x, v, a, z, err)

    tt = x / (a ** 2)  # normalized time
    p = ftt_01w(tt, z, err)  # f(tt | 0, 1, z)

    # The series in ftt_01w is truncated, so p can come back as exactly zero
    # or slightly negative. math.log raises ValueError for any non-positive
    # argument, so floor the log at -500 in both cases rather than only when
    # p == 0.
    logp = log(p) if p > 0 else -500

    # Work in log space, then rescale to f(x | v, a, z, sv).
    exponent = logp + ((a * z * sv) ** 2 - 2 * a * v * z - (v ** 2) * x) / (
        2 * (sv ** 2) * x + 2
    )
    return exp(exponent) / sqrt((sv ** 2) * x + 1) / (a ** 2)
def full_pdf(
    x, v, sv, a, z, sz, t, st, err=1e-4, n_st=2, n_sz=2, use_adaptive=1, simps_err=1e-3
):
    """Compute the probability density function of the full drift diffusion model.

    Computes f(x | v, a, z, sv, sz, t, st) using the method of Navarro and
    Fuss (2009) for the basic DDM likelihood, analytic integration over the
    drift rate when ``sv`` is non-zero (Tuerlinckx et al., 2001), and numeric
    integration over the non-decision time and/or the starting point when
    ``st`` and/or ``sz`` are non-zero (Ratcliff & Tuerlinckx, 2002).

    This function accepts negative or positive reaction times. Negative
    values correspond to the lower bound and positive values to the upper
    bound; for positive ``x`` the drift rate is negated and the starting
    point is replaced by ``1 - z``.

    ``sv``, ``sz`` and ``st`` below 1e-3 are treated as exactly zero.
    ``use_adaptive`` selects adaptive Simpson integration (with tolerance
    ``simps_err`` and ``n_st`` / ``n_sz`` as maximum recursion depths) over
    the fixed-step rule (with ``n_st`` / ``n_sz`` sub-intervals).
    """
    # Map an upper-bound response onto the lower-bound formulation.
    if x > 0:
        v = -v
        z = 1.0 - z
    x = fabs(x)  # absolute RT

    # Variability parameters that are really small are treated as absent.
    st = 0 if st < 1e-3 else st
    sz = 0 if sz < 1e-3 else sz
    sv = 0 if sv < 1e-3 else sv

    # No variability in starting point or non-decision time: closed form.
    if sz == 0 and st == 0:
        return pdf_sv(x - t, v, sv, a, z, err)

    if sz == 0:
        # Only the non-decision time varies: integrate over t.
        t_lo = t - st / 2.0
        t_hi = t + st / 2.0
        # This branch tests ``use_adaptive > 0`` while the others test
        # truthiness; the distinction is kept as in the original code.
        if use_adaptive > 0:
            return adapt_simpson_1d(
                x, v, sv, a, z, t, err, z, z, t_lo, t_hi, simps_err, n_st
            )
        return simpson_1d(x, v, sv, a, z, t, err, z, z, 0, t_lo, t_hi, n_st)

    z_lo = z - sz / 2.0
    z_hi = z + sz / 2.0
    if st == 0:
        # Only the starting point varies: integrate over z.
        if use_adaptive:
            return adapt_simpson_1d(
                x, v, sv, a, z, t, err, z_lo, z_hi, t, t, simps_err, n_sz
            )
        return simpson_1d(x, v, sv, a, z, t, err, z_lo, z_hi, n_sz, t, t, 0)

    # Both vary: integrate over z and t.
    t_lo = t - st / 2.0
    t_hi = t + st / 2.0
    if use_adaptive:
        return adapt_simpson_2d(
            x, v, sv, a, z, t, err, z_lo, z_hi, t_lo, t_hi, simps_err, n_sz, n_st
        )
    return simpson_2d(x, v, sv, a, z, t, err, z_lo, z_hi, n_sz, t_lo, t_hi, n_st)
def pdf_contaminant_uniform(x, w=0.05):
    """Return the density of a uniform contaminant distribution at ``x``.

    The density is ``w`` on the closed interval ``[-0.5 / w, 0.5 / w]`` and
    zero outside it.
    """
    half_width = 0.5 / w
    return w if -half_width <= x <= half_width else 0
def pdf_contaminant_exponential(x, l):
    """Return the density of an exponential contaminant distribution at ``x``.

    ``l`` is the rate parameter; the value is ``l * exp(-l * |x|)``, so the
    sign of ``x`` is ignored.

    NOTE(review): because the function is symmetric in ``x`` it integrates
    to 2 over the whole signed axis -- confirm whether it is meant as a
    density over ``|x|`` only.
    """
    magnitude = fabs(x)
    return l * exp(-l * magnitude)
def logpdf_with_contaminant_exponential(x, v, sv, a, z, sz, t, st, p_outlier, l):
    """Return the log likelihood of the full DDM mixed with an exponential contaminant.

    ``x`` is the signed reaction time passed to ``full_pdf``; ``v``, ``sv``,
    ``a``, ``z``, ``sz``, ``t`` and ``st`` are the DDM parameters.
    ``p_outlier`` is the mixture weight of the contaminant and ``l`` its
    rate.

    Returns ``-inf`` when a parameter is outside its valid range, when
    ``l <= 0`` while ``p_outlier`` is non-zero, or when the resulting
    density is not strictly positive.
    """
    # Reject invalid parameters. ``a`` must be strictly positive because the
    # density divides by a ** 2.
    if (
        z < 0
        or z > 1
        or a <= 0
        or t < 0
        or st < 0
        or sv < 0
        or sz < 0
        or sz > 1
        or z + sz / 2.0 > 1
        or z - sz / 2.0 < 0
        or t - st / 2.0 < 0
        or p_outlier < 0
        or p_outlier > 1
    ):
        return -inf

    if p_outlier == 0:
        density = full_pdf(x, v, sv, a, z, sz, t, st)
    else:
        if l <= 0:
            return -inf
        ddm_part = full_pdf(x, v, sv, a, z, sz, t, st) * (1 - p_outlier)
        contaminant_part = pdf_contaminant_exponential(x, l) * p_outlier
        density = ddm_part + contaminant_part

    # math.log raises ValueError for non-positive input (e.g. when the RT is
    # not larger than the non-decision time and there is no contaminant), so
    # map a zero or negative density to -inf explicitly.
    if density <= 0:
        return -inf
    return log(density)
def test():
    """Smoke test: print the log likelihood of simulated reaction times.

    Draws 10000 log-normal samples with a fixed seed, appends the negated
    first 1000 of them, and prints the index, the value and the log
    likelihood for every sample under one fixed parameter set.
    """
    from scipy.stats import lognorm
    import numpy as np

    np.random.seed(0)
    rts = lognorm.rvs(s=0.5, size=10000)
    rts = np.concatenate([rts, -rts[:1000]])

    # Fixed model parameters used for every sample.
    params = dict(
        v=-1.1,
        sv=0.01,
        a=1.0,
        z=0.5,
        sz=0.01,
        t=0.5,
        st=0.01,
        p_outlier=0.001,
        l=0.1,
    )
    for i, rt in enumerate(rts):
        print("i =", i)
        print("x =", rt)
        log_lik = logpdf_with_contaminant_exponential(rt, **params)
        print("y =", log_lik)
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment