Last active
February 12, 2021 11:10
-
-
Save cjayb/d7fa004782592b5557e9fd80fb96be0f to your computer and use it in GitHub Desktop.
Demonstrating effects of 'baseline_renormalize' in hnn-core
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import matplotlib.pyplot as plt

TSTOP = 5000.  # simulation length used below (ms)
recalc = True  # if False, re-load the dipole cached on disk instead

# Per-cell piecewise-linear L5 baseline used by `baseline_renormalize`:
# a constant offset up to 37 ms, then m * t + b until t1, then m1 * t + b1.
m = 3.4770508e-3
b = -51.231085
# these values were fit over the range [750., 5000]
t1 = 750.
m1 = 1.01e-4
b1 = -48.412078

# Apply the correction to an all-zero trace so the correction itself is
# visualised.
times = np.arange(0, 2500)
data = np.zeros((len(times), ))
N_pyr = 100  # number of pyramidal cells; the correction scales with it
dpl_offset = N_pyr * -49.0502
# The three time segments of the piecewise correction.
early = times <= 37.
middle = (times > 37.) & (times < t1)
late = times >= t1
data[early] -= dpl_offset
data[middle] -= N_pyr * (m * times[middle] + b)
data[late] -= N_pyr * (m1 * times[late] + b1)

plt.plot(times, 1e-6 * data)  # scaled by 1e-6 for display in nAm
plt.xlabel('ms')
plt.ylabel('nAm')
plt.suptitle('Piecewise-linear function applied in `baseline_renormalize`')
import os.path as op

import hnn_core
from hnn_core import simulate_dipole, read_params, Network

# Load hnn-core's default parameters and extend the simulation to TSTOP ms.
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
params['tstop'] = TSTOP
net = Network(params)

# Raw dipole without any post-processing: simulate it (and cache it to disk),
# or re-load the cached copy when `recalc` is False.
# NOTE(review): TSTOP is in ms, so the 's' in the file name is misleading; it
# is left unchanged so previously cached files are still found.
if recalc:
    dpl_nonorm = simulate_dipole(net, postproc=False)[0]
    dpl_nonorm.write(f'dpl_nonorm_{TSTOP:.0f}s.txt')
else:
    from hnn_core import read_dipole
    dpl_nonorm = read_dipole(f'dpl_nonorm_{TSTOP:.0f}s.txt')
times = dpl_nonorm.times

# Re-apply the piecewise-linear correction by hand to the raw L5 dipole.
# NOTE(review): N_pyr (100) is hard-coded; it must equal
# N_pyr_x * N_pyr_y of `params` for this correction to be meaningful.
dpl_renorm = dpl_nonorm.copy()
early = times <= 37.
middle = (times > 37.) & (times < t1)
late = times >= t1
dpl_renorm.data['L5'][early] -= dpl_offset
dpl_renorm.data['L5'][middle] -= N_pyr * (m * times[middle] + b)
dpl_renorm.data['L5'][late] -= N_pyr * (m1 * times[late] + b1)

fig, ax = plt.subplots(1, 1)
ax.plot(times, 1e-6 * dpl_nonorm.data['L5'])
ax.plot(times, 1e-6 * dpl_renorm.data['L5'], linewidth=3)

# hnn-core's own post-processing for comparison, with the extra scale factor
# disabled (fctr = 1).
dpl_postproc = dpl_nonorm.copy()
N_pyr_x = net.params['N_pyr_x']
N_pyr_y = net.params['N_pyr_y']
window_len = net.params['dipole_smooth_win']
fctr = 1  # net.params['dipole_scalefctr']
dpl_postproc.post_proc(N_pyr_x, N_pyr_y, window_len, fctr)
# NOTE(review): plotted without the 1e-6 factor used for the other traces;
# presumably post_proc already converts to nAm -- confirm.
ax.plot(times, dpl_postproc.data['L5'], linestyle='--')
# Simplest alternative: subtract the mean of the raw dipole.
ax.plot(times, 1e-6 * (dpl_nonorm.data['L5'] - dpl_nonorm.data['L5'].mean()),
        linestyle=':')
ax.set_xlabel('ms')
ax.set_ylabel('nAm')
ax.legend(['postproc=False', 'baseline_renormalize', 'BR + smooth',
           'No postproc + demean'])
fig.suptitle('Comparison of different normalisation methods')
# Burn-in duration? See when a moving average of the raw (blank) simulation
# stabilises.
avg_win_ms = 10  # moving-average window (ms)
span_ms = 100  # inspect window start times up to about span_ms (ms)
mean_win_len = int(avg_win_ms * 1e-3 * dpl_nonorm.sfreq)  # window in samples
# Window start indices in samples. Use the builtin `int`: the `np.int` alias
# was deprecated in NumPy 1.20 and removed in 1.24.
calc_stop = (span_ms - avg_win_ms // 2) * 1e-3 * dpl_nonorm.sfreq
calc_over = np.arange(0, calc_stop, dtype=int)
nonorm_mean_L5 = np.zeros((len(calc_over), ))
nonorm_mean_L2 = np.zeros((len(calc_over), ))
for idx, start in enumerate(calc_over):
    stop = start + mean_win_len
    nonorm_mean_L5[idx] = dpl_nonorm.data['L5'][start:stop].mean()
    nonorm_mean_L2[idx] = dpl_nonorm.data['L2'][start:stop].mean()

fig, axs = plt.subplots(1, 2)
win_start_ms = 1e3 * calc_over / dpl_nonorm.sfreq  # window start times (ms)
for ax, layer, trace, tick_step in zip(axs, ['L5', 'L2'],
                                       [nonorm_mean_L5, nonorm_mean_L2],
                                       [10, 20]):
    ax.plot(win_start_ms, 1e-6 * trace)
    ax.set_xlabel('ms')
    ax.set_ylabel('nAm')
    ax.set_xticks(np.arange(0, 100, tick_step))
    ax.grid(True)
    ax.legend([layer])
fig.suptitle(f'Moving {avg_win_ms} ms average over blank simulation')
# The burn-in artefact is over after about 30 ms, after which there is a long
# mono-exponential rise in current towards an asymptote of around -4.8 pAm
# (L5). Note that this is for 100 cells, so assuming a homogeneous
# contribution from all cells (NOT TESTED!), we get about -48 fAm per L5 cell.
from scipy.optimize import curve_fit


def fitfun(x, a, b, c):
    """Return ``a * exp(-b * x) + c``.

    With ``b > 0`` the curve moves from ``a + c`` at ``x = 0`` towards the
    asymptote ``c``. Note the minus sign in the exponent: a fitted ``b`` is
    the decay rate, not the exponent coefficient itself.
    """
    return a * np.exp(-b * x) + c
# Initial guesses (a, b, c) for `fitfun`, per cell (fAm, 1/ms). They were
# tuned on the 100-cell network. They must NOT be rescaled by the cell count,
# because the data are already divided by the number of cells before fitting.
P0_PER_CELL = {'L5': (-0.3e3 / 100, 2e-3, -4.8e3 / 100),
               'L2': (0.3 / 100, 2.25e-2, 4.45 / 100)}
fit_start = 30  # ms; skip the burn-in transient


def fit_blank_dipole(dpl, n_pyr, fit_start=30, title_suffix=''):
    """Fit `fitfun` to the per-cell L5 and L2 dipoles and plot the fits.

    Parameters
    ----------
    dpl : hnn-core Dipole
        Dipole simulated with ``postproc=False``; ``dpl.times`` is used as
        time in ms and ``dpl.data['L5']`` / ``dpl.data['L2']`` as the traces.
    n_pyr : int
        Number of pyramidal cells per layer. Each trace is divided by it, so
        the fitted curves are per cell.
    fit_start : float
        Only samples with ``times > fit_start`` (ms) are fitted.
    title_suffix : str
        Appended to the figure title.

    Returns
    -------
    fit_params : dict
        Maps 'L5' and 'L2' to the fitted curve written as a Python
        expression in ``t``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    mask = dpl.times > fit_start
    xdata = dpl.times[mask]
    fit_params = dict()
    for ax, layer in zip(axes, ['L5', 'L2']):
        ydata = dpl.data[layer][mask] / n_pyr
        ax.plot(xdata, ydata)
        popt, _ = curve_fit(fitfun, xdata, ydata, p0=P0_PER_CELL[layer])
        amp, rate, offset = popt
        # fitfun is a * exp(-b * t) + c, so the exponent coefficient is -b.
        legstr = f'fit: {amp:.2e} x \n exp( {-rate:.2e} t ) + {offset:.2e}'
        fit_params[layer] = (f'{amp:.4e} * np.exp({-rate:.4e} * t) '
                             f'+ {offset:.4e}')
        ax.plot(xdata, fitfun(xdata, *popt), linestyle='--', label=legstr)
        ax.set_title(layer)
        ax.set_xlabel('Time (ms)')
        ax.legend()
        ax.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))
    # The traces were divided by n_pyr, so this is a per-cell mean, not a sum.
    axes[0].set_ylabel(f'fAm (mean of {n_pyr} neurons)')
    fig.suptitle(f'Exponential fits to blank simulation after '
                 f'{fit_start:.0f} ms{title_suffix}')
    return fit_params


# Default network (cell count read before `params` is modified below).
n_pyr_default = net.params['N_pyr_x'] * net.params['N_pyr_y']
fit_params = fit_blank_dipole(dpl_nonorm, n_pyr_default, fit_start)
print('Exponential fits in units of fAm and ms:')
print(f"L5: {fit_params['L5']}")
print(f"L2: {fit_params['L2']}")
# Exponential fits in units of fAm and ms:
# L5: -3.6498e+00 * np.exp(-1.9647e-03 * t) + -4.8023e+01
# L2: 2.8063e-03 * np.exp(-1.1149e-02 * t) + 4.4301e-02

# Let's try with a tiny network!
params.update({'N_pyr_x': 3,
               'N_pyr_y': 3})
net_tiny = Network(params)
dpl_nonorm_tiny = simulate_dipole(net_tiny, postproc=False,
                                  record_isoma=True, record_vsoma=True)[0]
n_pyr_tiny = params['N_pyr_x'] * params['N_pyr_y']
fit_params = fit_blank_dipole(dpl_nonorm_tiny, n_pyr_tiny, fit_start,
                              title_suffix=' (tiny network)')
print('Exponential fits in units of fAm and ms (tiny network):')
print(f"L5: {fit_params['L5']}")
print(f"L2: {fit_params['L2']}")
# Exponential fits in units of fAm and ms (tiny network), recorded with an
# earlier initial guess that was mis-scaled by the cell count; the last digits
# may differ slightly now:
# L5: -3.6498e+00 * np.exp(-1.9647e-03 * t) + -4.8023e+01
# L2: 2.8055e-03 * np.exp(-1.1139e-02 * t) + 4.4301e-02
# Thankfully, the per-cell fits agree for networks of 100 and 9 cells in each
# layer: identical for L5 at the printed precision, and within 0.1 % for L2.
# Let's look at the somatic voltages of the tiny network: the first L5 and
# the first L2 pyramidal cell (lowest gid of each type).
# NOTE(review): index 0 presumably selects the first trial -- confirm.
vsoma = net_tiny.cell_response.vsoma[0]
isoma = net_tiny.cell_response.isoma[0]
vitimes = net_tiny.cell_response.times
L2_pyr_gid = net_tiny.gid_ranges['L2_pyramidal'][0]
L5_pyr_gid = net_tiny.gid_ranges['L5_pyramidal'][0]
visoma = {'Voltage': (vsoma[L5_pyr_gid], vsoma[L2_pyr_gid])}
# master is broken: isoma is a dict with keys 'soma_gabaa' and 'soma_gabab'?
#     'Current': (isoma[L5_pyr_gid]['soma_gabaa'],
#                 isoma[L2_pyr_gid]['soma_gabaa'])}

# One row of panels per measure and one column per cell type. Currently there
# is only 'Voltage', because the current recording is unusable. squeeze=False
# keeps `axs` 2-D, so re-enabling 'Current' needs no layout changes.
fig, axs = plt.subplots(len(visoma), 2, squeeze=False)
for ax_row, (measure, ydata_tuple) in zip(axs, visoma.items()):
    for ax, ydata, cell_type in zip(ax_row, ydata_tuple, ['L5', 'L2']):
        ax.plot(vitimes, ydata, label=cell_type)
        # Dashed reference line at the final value, labelled with that value.
        ax.axhline(ydata[-1], linestyle='--', color='red',
                   label=f'{ydata[-1]:.5e}')
        ax.set_ylabel(measure)
        ax.grid(axis='y')
        ax.legend()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Output plots: