lfads processing scripts
# take session-by-session output from lfads run in posterior sample and average
# mode, reassemble train and test trials, trim back to correct length (they
# were padded for batching), and bundle up in a single hdf5 file like
# the one used for behavior
import numpy as np
import h5py
import os

lfads_out_fpath = os.path.expanduser('~/data/penaltykick/model_data/lfads/fitted/')
lfads_in_fpath = os.path.expanduser('~/data/penaltykick/model_data/lfads/')
out_fpath = os.path.expanduser('~/data/penaltykick/model_data/')
out_fname = 'lfads_output.hdf5'
stem = 'pk_lfads_'

# this will be the eventual output file
out_hf = h5py.File(out_fpath + out_fname, 'w')

# open behavior file (to get trial lengths)
beh_hf = h5py.File(os.path.expanduser('~/data/penaltykick/model_data/compiled_penalty_kick_wspikes_wgaze_resplined.hdf5'), 'r')
sessions = [key for key in beh_hf.keys() if 'sess' in key]

for this_sess in sessions:
    print(this_sess)
    this_beh = beh_hf[this_sess]

    # input file to lfads (holds the train/test split)
    lfads_in_fname = stem + this_sess

    # output files from lfads (hold the posterior sample-and-average results)
    lfads_out_stem = 'model_runs_'
    lfads_out_suff = '_posterior_sample_and_average'
    lfads_out_train_fname = lfads_out_stem + this_sess + '_train' + lfads_out_suff
    lfads_out_valid_fname = lfads_out_stem + this_sess + '_valid' + lfads_out_suff
    try:
        lfads_infile = h5py.File(lfads_in_fpath + lfads_in_fname, 'r')
        lfads_outfile_train = h5py.File(lfads_out_fpath + lfads_out_train_fname, 'r')
        lfads_outfile_valid = h5py.File(lfads_out_fpath + lfads_out_valid_fname, 'r')
    except OSError:
        # no lfads files for this session; skip it
        continue

    is_train_mask = lfads_infile['is_train'][:].astype(bool)
    train_inds = list(lfads_infile['trial_list'][is_train_mask])
    valid_inds = list(lfads_infile['trial_list'][~is_train_mask])

    for int_trial, is_train in zip(lfads_infile['trial_list'][:], is_train_mask):
        this_trial = 'trial' + str(int_trial)
        curr_path = this_sess + '/' + this_trial
        T_beh = this_beh[this_trial].shape[1]

        if is_train:
            idx = train_inds.index(int_trial)
            factors = lfads_outfile_train['factors'][idx, :T_beh]
            rates = lfads_outfile_train['output_dist_params'][idx, :T_beh]
            control = lfads_outfile_train['controller_outputs'][idx, :T_beh]
        else:
            idx = valid_inds.index(int_trial)
            factors = lfads_outfile_valid['factors'][idx, :T_beh]
            rates = lfads_outfile_valid['output_dist_params'][idx, :T_beh]
            control = lfads_outfile_valid['controller_outputs'][idx, :T_beh]
        out_hf[curr_path + '/factors'] = factors
        out_hf[curr_path + '/rates'] = rates
        out_hf[curr_path + '/control'] = control

    lfads_infile.close()
    lfads_outfile_train.close()
    lfads_outfile_valid.close()

out_hf.close()
beh_hf.close()
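
A minimal sketch of reading the bundled file back, assuming only the layout written above (one group per session, one subgroup per trial, each holding time-by-dimension factors, rates, and control arrays):

import h5py
import os

with h5py.File(os.path.expanduser('~/data/penaltykick/model_data/lfads_output.hdf5'), 'r') as hf:
    for sess_name in hf:
        for trial_name in hf[sess_name]:
            grp = hf[sess_name][trial_name]
            # arrays were trimmed to the behavior length for this trial
            print(sess_name, trial_name,
                  grp['factors'].shape,
                  grp['rates'].shape,
                  grp['control'].shape)
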
# convert spiking data to format required by lfads
import numpy as np
import h5py
import os
from utils import write_datasets
from synth_data.synthetic_data_utils import add_alignment_projections  # needed to add __init__.py in synth_data

np.random.seed(12345)
# constants, etc.
DT = 0.15  # bin width in seconds
TRAIN_FRAC = 0.8  # fraction of trials assigned to the training split
NPCS = 10  # for alignment: MUST BE THE SAME AS EVENTUAL FACTOR NUMBER!
NTRIALS = 500  # for alignment
NTIMES = 800  # for alignment

fpath = os.path.expanduser('~/data/penaltykick/model_data/lfads/')
stem = 'pk_lfads'
hf = h5py.File(os.path.expanduser('~/data/penaltykick/model_data/compiled_penalty_kick_wspikes_wgaze_resplined.hdf5'), 'r')

sessions = [key for key in hf.keys() if 'sess' in key]
spks = hf['spikes']
datasets = {}

for sess_name in sessions:
    sess = hf[sess_name]
    if sess_name in spks:
        sess_spks = spks[sess_name]
    else:
        continue

    spk_list = []
    trial_list = []  # list of trials with spikes
    maxT = 0  # longest trial (in time bins) this session
    maxU = 0  # number of units this session
    minN = 128000  # running minimum of training-set size (tracked but not used below)
    for trial_name in sess.keys():
        trial = sess[trial_name]
        info = trial.attrs
        if info['Spikes'] and info['Complete'] and info['GameMode'] == 1:
            if not info['ReplayOldBarData'] or np.isnan(info['ReplayOldBarData']):
                if trial_name in sess_spks:
                    trial_spks = sess_spks[trial_name][:]
                    maxT = max(maxT, trial_spks.shape[1])
                    maxU = max(maxU, trial_spks.shape[0])
                    spk_list.append(trial_spks)
                    trial_list.append(int(trial_name[len('trial'):]))

    ntrials = len(spk_list)
    if ntrials == 0:
        continue
    # do train/test split
    trial_split = np.random.rand(ntrials) <= TRAIN_FRAC
    trial_arr = np.array(trial_list)
    ntrain = np.sum(trial_split)
    nval = ntrials - ntrain

    # allocate train and validation data
    # needed to use NTIMES, since lfads needs all trials to be same length
    train_data = np.zeros((ntrain, NTIMES, maxU), dtype='int64')
    valid_data = np.zeros((nval, NTIMES, maxU), dtype='int64')
    condition_labels_train = np.zeros((ntrain,))
    condition_labels_valid = np.zeros((nval,))

    nt, nv = 0, 0
    for ts, ss in zip(trial_split, spk_list):
        if ts:
            train_data[nt, :ss.shape[1], :ss.shape[0]] = ss.T
            nt += 1
        else:
            valid_data[nv, :ss.shape[1], :ss.shape[0]] = ss.T
            nv += 1
    data = {'train_data': train_data,
            'valid_data': valid_data,
            'dt': DT,
            'trial_list': trial_arr,  # integer (1-based) indices of trials with spike data
            'is_train': trial_split,  # True if trial is for training (keyed to trial_list)
            'condition_labels_train': condition_labels_train,
            'condition_labels_valid': condition_labels_valid}

    minN = min(minN, train_data.shape[0])
    datasets[sess_name] = data
# write a P_sxn matrix for each session mapping its recorded units into the full unit space
totU = np.sum([d['train_data'].shape[2] for _, d in datasets.items()])
accU = 0
for _, data in datasets.items():
    U = data['train_data'].shape[2]
    data['P_sxn'] = np.roll(np.eye(U, totU), accU, axis=1)
    accU += U
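# illustration (hypothetical sizes): with two sessions of 2 and 3 units, totU = 5,
# and each session's P_sxn selects a disjoint block of columns of the full space:
#   session 1 (accU = 0): [[1, 0, 0, 0, 0],
#                          [0, 1, 0, 0, 0]]
#   session 2 (accU = 2): [[0, 0, 1, 0, 0],
#                          [0, 0, 0, 1, 0],
#                          [0, 0, 0, 0, 1]]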
datasets = add_alignment_projections(datasets, npcs=NPCS, nsamples=NTRIALS, ntime=NTIMES)
write_datasets(fpath, stem, datasets)
hf.close()
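
As a quick sanity check, the per-session files written by write_datasets (read back by the first script as the stem plus the session key) can be inspected with a sketch like this; the field names match the dict assembled above:

import h5py
import os

fpath = os.path.expanduser('~/data/penaltykick/model_data/lfads/')
for fname in sorted(os.listdir(fpath)):
    if not fname.startswith('pk_lfads_'):
        continue
    with h5py.File(os.path.join(fpath, fname), 'r') as f:
        # train_data / valid_data are (trials x NTIMES x units) spike-count arrays
        print(fname,
              f['train_data'].shape,
              f['valid_data'].shape,
              f['is_train'][:].sum(), 'train trials')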