Created
March 31, 2021 20:18
-
-
Save sammosummo/24d9fcd3686d3d8ffa72864a85e91ffe to your computer and use it in GitHub Desktop.
JAX implementation of the full DDM log likelihood function
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""JAX functions for calculating the probability density of the Wiener diffusion first- | |
passage time (WFPT) distribution used in drift diffusion models (DDMs). | |
""" | |
import jax | |
import jax.numpy as jnp | |
def jax_wfpt_pdf_sv(x, v, sv, a, z, t):
    """Probability density function of the WFPT distribution with drift rates normally
    distributed over trials. When the standard deviation of drift-rate variability is 0,
    this reduces down to the "simple" DDM likelihood function without contaminants.

    The unit-boundary density f(tt|0,1,w) is computed with the small-time and large-time
    series of Navarro & Fuss (2009), each truncated to seven terms. For each element,
    the series whose error bound needs fewer terms is used. The result is then rescaled
    to boundary ``a`` and marginalized over a drift rate drawn from Normal(v, sv**2).

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Mean drift rate.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of decision upper bound. (0, inf).
        z: Normalized decision starting point. (0, 1).
        t: Non-decision time. [0, inf)

    Returns:
        Density of each element of ``x``, broadcast against the parameters. Elements
        whose decision time ``|x| - t`` is not positive have density 0 (with zero,
        finite gradients) rather than NaN.
    """
    x = jnp.asarray(x)

    # Upper-bound responses are evaluated as lower-bound responses of the mirrored
    # process (drift sign and starting point reflected).
    upper = x > 0
    v = jnp.where(upper, -v, v)
    z = jnp.where(upper, 1 - z, z)

    # Decision time. Where it is not positive, the density is 0. Substitute a harmless
    # placeholder so the math below (and its gradient) stays finite; the result is
    # masked at the end.
    dt = jnp.abs(x) - t
    valid = dt > 0
    dt = jnp.where(valid, dt, 1.0)
    tt = dt / a ** 2  # normalized time

    err = 1e-7  # error target; only used to decide which expansion to use

    # Number of terms the small-t expansion needs to reach `err`.
    small_arg = 2 * jnp.sqrt(2 * jnp.pi * tt) * err
    ks = jnp.where(
        small_arg < 1,
        jnp.maximum(2 + jnp.sqrt(-2 * tt * jnp.log(small_arg)), jnp.sqrt(tt) + 1),
        2.0,
    )

    # Number of terms the large-t expansion needs to reach `err`.
    large_arg = jnp.pi * tt * err
    kl_min = 1.0 / (jnp.pi * jnp.sqrt(tt))
    kl = jnp.where(
        large_arg < 1,
        jnp.maximum(kl_min, jnp.sqrt(-2 * jnp.log(large_arg) / (jnp.pi ** 2 * tt))),
        kl_min,
    )

    # Small-t expansion, k = -3..3 (Python loop, unrolled at trace time).
    ps = sum(
        (z + 2 * k) * jnp.exp(-((z + 2 * k) ** 2) / (2 * tt)) for k in range(-3, 4)
    ) / jnp.sqrt(2 * jnp.pi * tt ** 3)

    # Large-t expansion, k = 1..7.
    pl = jnp.pi * sum(
        k * jnp.exp(-(k ** 2) * jnp.pi ** 2 * tt / 2) * jnp.sin(k * jnp.pi * z)
        for k in range(1, 8)
    )

    # f(tt|0,1,w): keep whichever expansion needs fewer terms.
    normp = jnp.where(ks < kl, ps, pl)

    # Rescale to boundary `a` and integrate the drift over Normal(v, sv**2). The
    # correction uses the decision time (not the raw RT), and the 1/sqrt(...) and
    # 1/a**2 factors multiply the density; they are not part of the exponent.
    # Multiplying directly (instead of exp(log(p) + ...)) avoids NaN gradients when
    # normp underflows to 0.
    scale = sv ** 2 * dt + 1
    density = (
        normp
        * jnp.exp(((a * z * sv) ** 2 - 2 * a * v * z - v ** 2 * dt) / (2 * scale))
        / jnp.sqrt(scale)
        / a ** 2
    )
    return jnp.where(valid, density, 0.0)
def jax_wfpt_pdf_sv_sz(x, v, sv, a, lz, uz, t):
    """Probability density function of the WFPT distribution with normally distributed
    drift rate and uniformly distributed starting point.

    The density is averaged over z ~ Uniform(lz, uz) with a three-point Simpson's rule:
    (f(lz) + 4 * f((lz + uz) / 2) + f(uz)) / 6. When lz == uz this reduces exactly to
    f(lz), so no special case is needed.

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Drift rate if sv == 0 or mean drift rate if sv > 0.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of upper bound. (0, inf).
        lz: Lower bound on normalized starting point. (0, uz].
        uz: Upper bound on normalized starting point. [lz, 1).
        t: Non-decision time. [0, inf)

    Returns:
        Density of each element of ``x`` marginalized over the starting point.
    """
    f_lo = jax_wfpt_pdf_sv(x, v, sv, a, lz, t)
    f_mid = jax_wfpt_pdf_sv(x, v, sv, a, (lz + uz) / 2, t)
    f_hi = jax_wfpt_pdf_sv(x, v, sv, a, uz, t)
    # Simpson weights 1-4-1 over 6 give the *mean* of f on [lz, uz], i.e. the density
    # marginalized over a uniform starting point. The integral would instead be this
    # value scaled by (uz - lz), and it is not what is wanted here.
    return (f_lo + 4 * f_mid + f_hi) / 6
def jax_wfpt_pdf_sv_sz_st(x, v, sv, a, lz, uz, lt, ut):
    """Probability density function of the WFPT distribution with normally distributed
    drift rate, uniformly distributed starting point, uniformly distributed nondecision
    time.

    The starting-point-marginalized density is averaged over t ~ Uniform(lt, ut) with a
    three-point Simpson's rule: (f(lt) + 4 * f((lt + ut) / 2) + f(ut)) / 6. When
    lt == ut this reduces exactly to f(lt).

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Drift rate if sv == 0 or mean drift rate if sv > 0.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of upper bound. (0, inf).
        lz: Lower bound on normalized starting point. (0, uz].
        uz: Upper bound on normalized starting point. [lz, 1).
        lt: Lower bound on nondecision time. [0, ut].
        ut: Upper bound on nondecision time. [lt, inf).

    Returns:
        Density of each element of ``x`` marginalized over starting point and
        nondecision time.
    """
    f_lo = jax_wfpt_pdf_sv_sz(x, v, sv, a, lz, uz, lt)
    f_mid = jax_wfpt_pdf_sv_sz(x, v, sv, a, lz, uz, (lt + ut) / 2)
    f_hi = jax_wfpt_pdf_sv_sz(x, v, sv, a, lz, uz, ut)
    # Simpson weights 1-4-1 over 6 give the mean of f on [lt, ut], i.e. the density
    # marginalized over a uniform nondecision time.
    return (f_lo + 4 * f_mid + f_hi) / 6
def jax_wfpt_pdf_sv_sz_st_q(x, v, sv, a, lz, uz, lt, ut, q):
    """Probability density function of the WFPT distribution with normally distributed
    drift rate, uniformly distributed starting point, uniformly distributed nondecision
    time, and uniformly distributed contaminants.

    Contaminants have uniform response times on [0, max|x|] (taken over the whole of
    ``x``) and are equally likely to land on either bound. Their density over signed
    RTs is therefore 0.5 / max|x|, so the mixture integrates to 1 over both responses,
    just as the WFPT component does.

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Drift rate if sv == 0 or mean drift rate if sv > 0.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of upper bound. (0, inf).
        lz: Lower bound on normalized starting point. (0, uz].
        uz: Upper bound on normalized starting point. [lz, 1).
        lt: Lower bound on nondecision time. [0, ut].
        ut: Upper bound on nondecision time. [lt, inf).
        q: Contaminant probability. [0, 1].

    Returns:
        Mixture density ``(1 - q) * f + q * p_contaminant`` for each element of ``x``.
    """
    f = jax_wfpt_pdf_sv_sz_st(x, v, sv, a, lz, uz, lt, ut)
    # The 0.5 splits the contaminant mass evenly between the two response signs.
    p_cont = 0.5 / jnp.max(jnp.abs(x))
    return (1 - q) * f + q * p_cont
def jax_wfpt_sumlogp(x, v, sv, a, lz, uz, lt, ut, q):
    """Total log-likelihood of ``x`` under the contaminated WFPT model.

    Evaluates ``jax_wfpt_pdf_sv_sz_st_q`` element-wise, takes the natural log of each
    density, and sums the results. A zero density contributes ``-inf``.

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Drift rate if sv == 0 or mean drift rate if sv > 0.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of upper bound. (0, inf).
        lz: Lower bound on normalized starting point. (0, uz].
        uz: Upper bound on normalized starting point. [lz, 1).
        lt: Lower bound on nondecision time. [0, ut].
        ut: Upper bound on nondecision time. [lt, inf).
        q: Contaminant probability

    Returns:
        Scalar sum of the per-observation log densities.
    """
    densities = jax_wfpt_pdf_sv_sz_st_q(x, v, sv, a, lz, uz, lt, ut, q)
    return jnp.log(densities).sum()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment