Last active
August 29, 2015 14:14
-
-
Save d1manson/f1d8907865298808941d to your computer and use it in GitHub Desktop.
Python impl. of Climer et al. 2014, a max-likelihood estimation of temporal autocorrelogram parameters (for in vivo neuronal spiking)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Adapted from the Matlab code acompanying the following paper: | |
Climer, J. R., DiTullio, R., Newman, E. L., Hasselmo, M. E., Eden, U. T. | |
(2014), Examination of rhythmicity of extracellularly recorded neurons | |
in the entorhinal cortex. Hippocampus, Epub ahead of print. | |
doi: 10.1002/hipo.22383. | |
Matlab code take from: github.com/jrclimer/mle_rhythmicity, on 02-feb-2015. | |
This python version is by DM, Feb 2015. | |
Module contents: | |
# namedtuple ClimerPHat | |
# namedtuple ClimerFit | |
# statpoisci <-- port of a Matlab stats function | |
# rhythmicity_pdf | |
# plot_rhythmicity_pdf <-- calls rhythmicity_pdf and then produces a simple plot | |
# theta_mle_climer <-- this is the main entry point for mle-fitting | |
# plot_theta_mle_climer <-- this is a wrapper around theta_mle_climer, it plots outputs | |
# count_to and repeat_ind <-- some indexing tools copied from Daniel's: | |
https://gist.github.com/d1manson/c4dc7a901f63f0286883 | |
In the comments, the abreviation [S.O.] means valid only when theta_skipping=True | |
Uses modfied version of pyswarm code for partical swarm optimsation. See: | |
https://gist.github.com/d1manson/40cbbb62a5f4ecc37bd7 | |
or original at http://pythonhosted.org/pyswarm/ | |
Note that original Matlab code used MLE after PSO, looking at Matlab it seems that | |
MLE is the Nelder-Mead algorithm, here we use L-BFGS-B instead. You can selecct | |
Nelder-Mead in scipy.optimise.minimze, but it doesnt let you specify bounds, | |
which is annoying. | |
Only limited testing has been done. Use at your own risk, but please let the | |
author know if you find bugs! | |
""" | |
from collections import namedtuple | |
from numpy import sqrt | |
from numpy import pi, cos, exp | |
from scipy.stats import chi2, norm | |
import numpy as np | |
from pso import pso # Partical swarm optimisation | |
from scipy.optimize import minimize | |
import matplotlib.pylab as plt | |
realmin = np.finfo(np.double).tiny #same as in matlab | |
ClimerPHat = namedtuple('ClimerPHat',','.join( ( | |
'tau', # Exponential falloff rate of the whole distribution (log10(sec)) | |
'b', # baseline probability | |
'c', # Falloff rate of the rhythmicity magnitude (log10(sec)) | |
'f', # Frequency of the rhymicity (Hz) | |
'r', # Rhythmicity | |
's' # Skipping [S.O.] | |
))) | |
ClimerPHat_len = 6 | |
ClimerFit = namedtuple('ClimerFit',', '.join(( | |
'theta_skipping', # records whether theta_skipping was True or False in fit request | |
'freq_trial', # Firing frequency (Hz) | |
'freq_trial_CI', # 95% confidence interval for above | |
'freq_win', # Firing frequency in each window (Hz) | |
'freq_win_CI', # 95% confidence interval for above | |
'hist', # histogram counts of lags | |
'bin_size_sec', # bin_size_sec of histogram bins | |
'LL_flat',# Log liklihood of the arrhythmic fit | |
'LL_noskip',# Log liklihood of the non-skipping fit | |
'LL_skip', # Log liklihood of full fit with skipping [S.O.] | |
'p_hat_flat', # MLE estiamtors for on rhytmicity, see ClimerPHat | |
'p_hat_noskip', #MLE estimators for no theta-skipping see ClimerPHat | |
'p_hat_skip', #MLE estimators full see ClimerPHat [S.O.] | |
'D_flat_v_noskip', # Deviance of flat versus rhythmic fit ... | |
'p_flat_v_noskip', # ... and p value | |
'D_noskip_v_skip', # Deviamce of non-skippiong versus skipping fit [S.O.]... | |
'p_noskip_v_skip', # ... and p value [S.O] | |
))) | |
def statpoisci(m, lambdahat, alpha=0.05): | |
""" ported from MATLAB's function of the same name. | |
confidence interval for Poisson lambda parameter | |
m and lambdahat should be broadcastable to the same shape. """ | |
lsum = np.asanyarray(m*lambdahat) | |
use_exact = lsum < 100 | |
lb, ub = np.empty(lsum.shape), np.empty(lsum.shape) | |
# Chi-square exact method | |
lb[use_exact] = chi2.ppf(alpha/2, 2*lsum[use_exact])/2 | |
ub[use_exact] = chi2.ppf(1-alpha/2, 2*(lsum[use_exact]+1))/2 | |
# Normal approximation | |
lb[~use_exact] = norm.ppf(alpha/2, lsum[~use_exact], sqrt(lsum[~use_exact])) | |
ub[~use_exact] = norm.ppf(1 - alpha/2, lsum[~use_exact], sqrt(lsum[~use_exact])) | |
return (lb/m, ub/m) | |
def rhythmicity_pdf(p_hat, x, x_meaning='bin_inds', | |
theta_skipping=True, normalize=True, | |
as_log=False, compute_total_L=False, | |
bin_size_sec=0.001, max_lag_sec=0.6, | |
default_phat=[np.nan,np.nan,1,1,0,0], | |
): | |
""" Parametric distribution of lags for rhythmic neurons | |
RHYTHMICITY_PDF generates the PMF of lags based on the parameters in p_hat. | |
TAKES: | |
p_hat - see ClimerPHat above for details. | |
It can be a ClimerPHat/list/tuple/1d-array or a 2d-array. | |
Iterating over it corresponds to iterating over a ClimerPHat instance. | |
Thus, p_hat[0] corresponds to the first field in ClimerPHat etc. | |
However it is permitted to be shorter than ClimerPHat, in such cases | |
additional values are added using default_phat. | |
The inidivual values may be scalars or vectors. If vectors they should | |
all be of the same length, (although you may mix vectors and scalars). | |
x,x_meaning - if x_meaning='bin_inds' then x is a full list of lags expressed | |
as bin indices. | |
if x_meaning='hist' then x is a histogram of the bin_indices, | |
this is useful in combination with compute_total_L=True, as we | |
can more efficiently sum over a histogram. | |
normalize - when true, normalizes the PDF to be a probability distribution | |
(The integral over the window is 1). If false, the value at 0 is 1. | |
as_log - when true, take logarithm of output. It is generally more efficient | |
to do this inside the function. | |
RETURNS: | |
L: The liklihood or log-likilihood of each lag. | |
See top of module for notes on authorship etc. | |
TODO: computing all the cosines is pretty slow. If you really want speed then | |
it might be worth considering having a table of pre-computed values. You would | |
round freq to your desired precision and then lookup or create the required | |
cos(2 pi f t) and cos(pi f t) arrays. It might be easier and possibly better | |
to only deal with the case where f is a vector, and check to see if there are | |
not that many distinct values. Could then construct a temproary cosine table. | |
Could do the same for sqrt(1-s), athough that isn't as pressing. | |
Note that this is slightly different to, though inspired by, the optimisation | |
implemented in the original Matlab code. | |
""" | |
# Force params to be a list of 1d vectors and/or scalars | |
# Note vectors will need to be the same length, or an exception will occur later | |
if isinstance(p_hat,np.ndarray): | |
if p_hat.ndim == 1 or min(p_hat.shape) == 1: | |
p_hat = list(p_hat) | |
elif p_hat.ndim == 2: | |
if p_hat.shape[0] > ClimerPHat_len: | |
raise Exception("Got 2D p_hat but first dim is longer than number of params") | |
p_hat = list(p_hat) | |
else: | |
raise Exception("p_hat must be 1D or 2D, not more.") | |
else: | |
p_hat = list(map(lambda p: np.asarray(p).ravel(), p_hat)) | |
# Pad with additional params if neccessary | |
if len(p_hat) < ClimerPHat_len: | |
p_hat += map(lambda x: np.asarray(x).ravel(), default_phat[len(p_hat):]) | |
# construct the namedtuple, adding an additional dimension to all arrays | |
p_hat = ClimerPHat._make(map(lambda p: p[:,np.newaxis] if p.ndim >0 else p,p_hat)) | |
if np.any(p_hat.b) < 0: | |
raise Exception("p_hat's b value(s) must be non-negative.") | |
if theta_skipping: | |
F_t = lambda t: 1.0/4.0*( \ | |
(2+2*sqrt(1-p_hat.s)-p_hat.s) \ | |
* cos(2*pi*p_hat.f*t) \ | |
+ 4*p_hat.s*cos(pi*p_hat.f*t) \ | |
+ 2 \ | |
- 2*sqrt(1-p_hat.s) \ | |
- 3*p_hat.s ) | |
else: | |
F_t = lambda t: cos(2*pi*p_hat.f*t) | |
L_t = lambda t: (1-p_hat.b) \ | |
* exp(-t*10**-p_hat.tau) \ | |
*( p_hat.r \ | |
* exp(-t*10**-p_hat.c) \ | |
* F_t(t) \ | |
+ 1) \ | |
+ p_hat.b | |
# compute likelihood lookup table, which will be of | |
# shape [m x n_bins] where m is the number of paramater sets | |
n_bins = np.ceil(max_lag_sec/bin_size_sec) | |
bin_centres = np.arange(n_bins+1)*bin_size_sec + bin_size_sec/2 | |
L_bin = L_t(bin_centres[np.newaxis,:]) | |
if normalize: | |
L_bin /= np.sum(L_bin,axis=1,keepdims=True) | |
L_bin[(L_bin < 0) | np.isnan(L_bin) | (L_bin == np.inf)] = realmin | |
if as_log: | |
L_bin = np.log(L_bin) | |
if x_meaning == 'bin_inds': | |
# lookup likelihood for each of the data points | |
return L_bin[:,x] | |
elif x_meaning == 'hist': | |
if compute_total_L is False: | |
raise Exception("Providing a hist for x only makes sense in conjuction with" \ | |
" computing the total likelihood, so set compute_total_L=True.") | |
# Multiply each bin's likelhood by the number of counts in the bin | |
# then sum over all bins | |
return np.dot(L_bin,x.ravel()) | |
else: | |
raise Exception("Unrecognised 'x_meaning'.") | |
def theta_mle_climer( times, duration, max_lag_sec=0.6, | |
theta_skipping=True,bin_size_sec=0.001, | |
lb=ClimerPHat(-1, 0, -1, 1, 0, 0), | |
ub=ClimerPHat(1, 1, 1, 13, 1, 1), | |
full_output=True): | |
""" Use maximum liklihood estimation to find rhythmicity parameters. | |
See top of module for authorship notes etc. | |
Takes: | |
times - spike times in seconds, sorted. | |
duration - trial duration in seconds | |
max_lag_sec - as it says | |
bin_size_sec - as it says | |
theta_skipping - if false we fit without using thet theta-skipping part of the model | |
full_output - if false we do not calculate histogram of freq trial/win | |
Returns a ClimerFit namedtuple. see top of module for meaning of attributes. | |
TODO: in minimize we want to be able to specify alpha =0.05 | |
""" | |
n_spikes = len(times) | |
iend = times.searchsorted(times + max_lag_sec, side='right') | |
counts_in_window = iend-np.arange(n_spikes)-1 | |
b_ind = repeat_ind(counts_in_window) | |
a_ind = count_to(counts_in_window) + 1 + b_ind | |
x = times[a_ind] - times[b_ind] | |
x_bin_inds = np.floor(x/bin_size_sec).astype(np.int32) | |
if full_output: | |
# Firing rate is the mean count per second. | |
# Assume firing is Poisson, and get confidence interval on this mean value. | |
freq_trial = float(n_spikes)/duration | |
freq_trial_CI = statpoisci(duration, freq_trial) | |
# Above was for whole trial, now do separtely for each window | |
freq_win = counts_in_window/max_lag_sec | |
freq_win_CI = statpoisci(max_lag_sec, freq_win) | |
# Climer divides the above by freq_trial to get a "multiplier" | |
# make histrogram from bin inds | |
hist = np.bincount(x_bin_inds,minlength=np.ceil(max_lag_sec/bin_size_sec)+1) | |
# Prepare the list of common inputs to rhythmicity_pdf and the main minimizer | |
kwargs_for_pdf = dict(x=hist, | |
x_meaning='hist', | |
as_log=True, | |
compute_total_L=True, | |
max_lag_sec=max_lag_sec, | |
bin_size_sec=bin_size_sec) | |
kwargs_for_minimize = dict(options={'disp':False}, | |
method='L-BFGS-B') | |
# Initial guess using particle swarm | |
p_hat_0,_ = pso( \ | |
lambda p_hat: -rhythmicity_pdf(p_hat, theta_skipping=False, **kwargs_for_pdf), | |
lb=lb[:-1], ub=ub[:-1], swarmsize=75, maxiter=100, info=False, func_takes_multiple=True) | |
# Solve for no-skipping | |
opt_res = minimize( lambda p_hat: -rhythmicity_pdf(p_hat, theta_skipping=False, **kwargs_for_pdf), | |
x0=p_hat_0, bounds=zip(lb[:-1],ub[:-1]), **kwargs_for_minimize) | |
LL_noskip, p_hat_noskip = opt_res.fun, ClimerPHat(*np.append(opt_res.x,0)) | |
# Arrhythmic fit, use p_hat_0 from above (i.e. non-skipping version) | |
opt_res = minimize( lambda p_hat: -rhythmicity_pdf(p_hat, theta_skipping=False, **kwargs_for_pdf), | |
x0=p_hat_0[:2], bounds=zip(lb[:2],ub[:2]),**kwargs_for_minimize) | |
LL_flat, p_hat_flat = opt_res.fun, ClimerPHat(*np.append(opt_res.x,[1,1,0,0])) | |
D_flat_v_noskip = 2*(LL_noskip-LL_flat) | |
p_flat_v_noskip = 1-chi2.cdf(D_flat_v_noskip, len(lb)-2) # TODO: check this is right | |
if theta_skipping: #Need to do full fit... | |
# Again, initial guess using particle swarm | |
p_hat_0,_ = pso( \ | |
lambda p_hat: -rhythmicity_pdf(p_hat, theta_skipping=True, **kwargs_for_pdf), | |
lb=lb, ub=ub, swarmsize=75, maxiter=100, info=False, func_takes_multiple=True) | |
# Again, solve, but now for full, skipping case | |
opt_res = minimize( lambda p_hat: -rhythmicity_pdf(p_hat, theta_skipping=True, **kwargs_for_pdf), | |
x0=p_hat_0, bounds=zip(lb,ub), **kwargs_for_minimize) | |
LL_skip, p_hat_skip = opt_res.fun, ClimerPHat(*opt_res.x) | |
D_noskip_v_skip = 2*(LL_skip-LL_noskip) | |
p_noskip_v_skip = 1-chi2.cdf(D_noskip_v_skip, 1) | |
# Collect everuything required for output | |
locals_ = locals() | |
return ClimerFit(**{k: locals_.get(k,None) for k in ClimerFit._fields}) | |
def plot_rhythmicity_pdf(p_hat,max_lag_sec=0.6,bin_size_sec=0.001,ax=None): | |
n_bins = np.ceil(max_lag_sec/bin_size_sec).astype(int) | |
y = rhythmicity_pdf(p_hat, | |
np.arange(n_bins), | |
normalize=False, | |
theta_skipping= False if p_hat.s == 0 else True).ravel() | |
if ax is None: | |
plt.cla() | |
plt.plot(np.arange(n_bins)*bin_size_sec + 0.5*bin_size_sec,y,'r-') | |
def plot_theta_mle_climer(times=None,duration=None,r=None,show_legend=True,**kwargs): | |
""" Calls theta_mle_climer and produces | |
a plot of the results. | |
You can provide times, duration, and kwargs for theta_mle_climer, | |
in which case we will run the fitting computation. Alternatively, | |
you can provide r, the returned ClimerFit tuple, we can then plot | |
this directly. | |
This function returns the output of theta_mle_climer. """ | |
if r is None: | |
r = theta_mle_climer(times,duration,**kwargs) | |
plt.cla() | |
bin_lefts = np.arange(len(r.hist)) * r.bin_size_sec | |
N = np.sum(r.hist) | |
plt.bar(bin_lefts,r.hist,width=r.bin_size_sec,ec='none',color=[0.7]*3) | |
plt.xlabel('time since spike (s)') | |
plt.ylabel('count') | |
plt.xlim(0,bin_lefts[-1]+r.bin_size_sec) | |
plt.hold(True) | |
fit_flat = rhythmicity_pdf(r.p_hat_flat, | |
np.arange(len(r.hist)), | |
normalize=True, | |
theta_skipping=False).ravel() | |
lg_flat, = plt.plot(bin_lefts + r.bin_size_sec/2, | |
fit_flat*N,'g-') | |
fit_noskip = rhythmicity_pdf(r.p_hat_noskip, | |
np.arange(len(r.hist)), | |
normalize=True, | |
theta_skipping=False).ravel() | |
lg_noskip, = plt.plot(bin_lefts + r.bin_size_sec/2, | |
fit_noskip*N,'r-') | |
lg = (lg_flat,lg_noskip) | |
lg_keys = ('no rhythmicity','no skipping') | |
maxy = 1.2*fit_noskip[0]*N | |
if r.theta_skipping: | |
fit_skip = rhythmicity_pdf(r.p_hat_skip, | |
np.arange(len(r.hist)), | |
normalize=True, | |
theta_skipping=True).ravel() | |
lg_skip, = plt.plot(bin_lefts + r.bin_size_sec/2, | |
fit_skip*N,'b-') | |
lg = lg + (lg_skip,) | |
lg_keys = lg_keys + ('with skipping',) | |
maxy = 1.2*fit_skip[0]*N | |
plt.ylim(0,maxy) | |
if show_legend: | |
plt.legend(lg,lg_keys) | |
plt.text(0.97,0.02,'mean freq={:0.1f}Hz'.format(r.freq_trial),va='bottom',ha='right',weight='bold',transform=plt.gca().transAxes) | |
plt.title('s={:0.2f} f={:0.1f} r={:0.2f}'.format(r.p_hat_skip.s,r.p_hat_skip.f,r.p_hat_skip.r)) | |
return r | |
def repeat_ind(n): | |
"""By example: | |
# 0 1 2 3 4 5 6 7 8 | |
n = [0, 0, 3, 0, 0, 2, 0, 2, 1] | |
res = [2, 2, 2, 5, 5, 7, 7, 8] | |
That is the input specifies how many times to repeat the given index. | |
It is equivalent to something like this : | |
hstack((zeros(n_i,dtype=int)+i for i, n_i in enumerate(n))) | |
But this version seems to be faster, and probably scales better, at | |
any rate it encapsulates a task in a functoin. | |
""" | |
if n.ndim != 1: | |
raise Exception("n is supposed to be 1d array.") | |
n_mask = n.astype(bool) | |
n_inds = np.nonzero(n_mask)[0] | |
n_inds[1:] = n_inds[1:]-n_inds[:-1] # take diff and leave 0th value in place | |
n_cumsum = np.empty(len(n)+1,dtype=int) | |
n_cumsum[0] = 0 | |
n_cumsum[1:] = np.cumsum(n) | |
ret = np.zeros(n_cumsum[-1],dtype=int) | |
ret[n_cumsum[n_mask]] = n_inds # note that n_mask is 1 element shorter than n_cumsum | |
return cumsum(ret) | |
def count_to(n): | |
"""By example: | |
# 0 1 2 3 4 5 6 7 8 | |
n = [0, 0, 3, 0, 0, 2, 0, 2, 1] | |
res = [0, 1, 2, 0, 1, 0, 1, 0] | |
That is it is equivalent to something like this : | |
hstack((arange(n_i) for n_i in n)) | |
This version seems quite a bit faster, at least for some | |
possible inputs, and at any rate it encapsulates a task | |
in a function. | |
""" | |
if n.ndim != 1: | |
raise Exception("n is supposed to be 1d array.") | |
n_mask = n.astype(bool) | |
n_cumsum = np.cumsum(n) | |
ret = np.ones(n_cumsum[-1]+1,dtype=int) | |
ret[n_cumsum[n_mask]] -= n[n_mask] | |
ret[0] -= 1 | |
return np.cumsum(ret)[:-1] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment