Skip to content

Instantly share code, notes, and snippets.

@robintibor
Last active May 18, 2021 16:12
Show Gist options
  • Save robintibor/fec0ecafdf5dc5c65584c42ee58dcd58 to your computer and use it in GitHub Desktop.
Save robintibor/fec0ecafdf5dc5c65584c42ee58dcd58 to your computer and use it in GitHub Desktop.
ecog_trajs_v31_visu_reps_trained.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 4 11:02:45 2019
visualization script (amplitude and phase perturbations) for change speed paradigm (car game)
works only for test set (last xval fold)
source data: v9_realTime (see exportData_4DNN_v9_realTime.m)
based on: ecog_trajs_v9_visu.py
NEW: saves individual repetitions of perturbations for further stat. signif. tests
- limits freq range to enable more repetitions (else results take too much space)
@author: jiri
"""
# %% settings: import libraries
import numpy as np
from matplotlib import pyplot as plt # equiv. to: import matplotlib.pyplot as plt
import scipy.io as sio
import logging
log = logging.getLogger()
log.setLevel("DEBUG")
import sys
if __name__ == "__main__":
logging.basicConfig(
format="%(asctime)s %(levelname)s : %(message)s",
level=logging.DEBUG,
stream=sys.stdout,
)
print("sys args = ", sys.argv)
if len(sys.argv) == 1:
n_job = "1"
else:
n_job = sys.argv[1]
print("your input was: " + sys.argv[1])
def file_for_number(x):
return {
"1": "ALL_11_FR1_day1_xvel",
"2": "ALL_11_FR1_day1_absVel",
"3": "ALL_11_FR2_day2_xvel",
"4": "ALL_11_FR2_day2_absVel",
"5": "ALL_11_FR3_day2_xvel",
"6": "ALL_11_FR3_day2_absVel",
"7": "ALL_13_FR1_day2_xvel",
"8": "ALL_13_FR1_day2_absVel",
"9": "ALL_14_PR1_day1_xvel",
"10": "ALL_14_PR1_day1_absVel",
"11": "ALL_15_PR1_day1_xvel",
"12": "ALL_15_PR1_day1_absVel",
"13": "ALL_15_PR4_day2_xvel",
"14": "ALL_15_PR4_day2_absVel",
"15": "ALL_16_PR7_day1_xvel",
"16": "ALL_16_PR7_day1_absVel",
"17": "ALL_17_PR14_day1_xvel",
"18": "ALL_17_PR14_day1_absVel",
"19": "ALL_17_PR16_day1_xvel",
"20": "ALL_17_PR16_day1_absVel",
"21": "ALL_18_PR3_day1_xvel",
"22": "ALL_18_PR3_day1_absVel",
"23": "ALL_18_PR6_day1_xvel",
"24": "ALL_18_PR6_day1_absVel",
}.get(
x, "ALL_11_FR1_day1_xpos"
) # latter is default if x not found
fileName = file_for_number(n_job)
print("file = " + fileName)
# %% local (CPU) or remote (GPU cluster) computing
remoteComputing = True
if remoteComputing:
dir_sourceData = "/data/hammerj/sourceData/v9_realTime"
dir_outputData = "/data/hammerj/outputData/v31_strideAfter_visuReps/trained"
import torch
log.info("CUDA is avalaible? {:d}".format(torch.cuda.is_available()))
cuda = True # You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
maxTrainEpochs = 100
N_perturbations = 1000
else:
dir_sourceData = "/home/jiri/data/sourceData/v9_realTime"
dir_outputData = "/home/jiri/data/outputData/v9_realTime"
cuda = False
maxTrainEpochs = 10
N_perturbations = 10
from braindecode.torch_ext.util import np_to_var
# %% Load data: matlab cell array
import h5py
log.info("Loading data...")
with h5py.File(dir_sourceData + "/" + fileName + ".mat", "r") as h5file:
sessions = [h5file[obj_ref] for obj_ref in h5file["D"][0]]
Xs = [session["ieeg"][:] for session in sessions]
ys = [session["traj"][0] for session in sessions]
srates = [session["srate"][0, 0] for session in sessions]
# %% create datasets
from braindecode.datautil.signal_target import SignalAndTarget
# Outer added axis is the trial axis (size one always...)
datasets = [
SignalAndTarget([X.astype(np.float32)], [y.astype(np.float32)])
for X, y in zip(Xs, ys)
]
from braindecode.datautil.splitters import concatenate_sets
# only for allocation
assert len(datasets) >= 4
train_set = concatenate_sets(datasets[:-1])
valid_set = datasets[-2] # dummy variable, could be set to None
test_set = datasets[-1]
# %% create model
from braindecode.models.deep4 import Deep4Net
from torch import nn
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model
from braindecode.torch_ext.modules import Expression
set_random_seeds(seed=20170629, cuda=cuda)
# This will determine how many crops are processed in parallel
input_time_length = 1200
n_classes = 1
in_chans = train_set.X[0].shape[0]
model = Deep4Net(
in_chans=in_chans,
n_classes=1,
input_time_length=input_time_length,
final_conv_length=2,
stride_before_pool=False,
).create_network()
# remove softmax
new_model = nn.Sequential()
for name, module in model.named_children():
if name == "softmax":
break
new_model.add_module(name, module)
# lets remove empty final dimension
def squeeze_out(x):
assert x.size()[1] == 1 and x.size()[3] == 1
return x[:, 0, :, 0]
new_model.add_module("squeeze", Expression(squeeze_out))
model = new_model
to_dense_prediction_model(model)
if cuda:
model.cuda()
from copy import deepcopy
start_param_values = deepcopy(new_model.state_dict())
# %% setup optimizer -> new for each x-val fold
from torch import optim
# %% # determine output size
from braindecode.torch_ext.util import np_to_var
test_input = np_to_var(
np.ones((2, in_chans, input_time_length, 1), dtype=np.float32)
)
if cuda:
test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[1]
log.info("predictor length = {:d} samples".format(n_preds_per_input))
log.info("predictor length = {:f} s".format(n_preds_per_input / srates[0]))
# crop size is: input_time_length - n_preds_per_input + 1
# print("crop size = {:d} samples".format(input_time_length - n_preds_per_input + 1))
# %% Iterator
from braindecode.torch_ext.losses import log_categorical_crossentropy
from braindecode.experiments.experiment import Experiment
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.experiments.monitors import (
RuntimeMonitor,
LossMonitor,
CroppedTrialMisclassMonitor,
MisclassMonitor,
)
from braindecode.experiments.stopcriteria import MaxEpochs
import torch.nn.functional as F
import torch as th
from braindecode.torch_ext.modules import Expression
# Iterator is used to iterate over datasets both for training and evaluation
iterator = CropsFromTrialsIterator(
batch_size=32,
input_time_length=input_time_length,
n_preds_per_input=n_preds_per_input,
)
# %% monitor for correlation
from braindecode.experiments.monitors import compute_preds_per_trial_from_crops
class CorrelationMonitor1d(object):
"""
Compute correlation between 1d predictions
Parameters
----------
input_time_length: int
Temporal length of one input to the model.
"""
def __init__(self, input_time_length=None):
self.input_time_length = input_time_length
def monitor_epoch(
self,
):
return
def monitor_set(
self, setname, all_preds, all_losses, all_batch_sizes, all_targets, dataset
):
"""Assuming one hot encoding for now"""
assert (
self.input_time_length is not None
), "Need to know input time length..."
# this will be timeseries of predictions
# for each trial
# braindecode functions expect classes x time predictions
# so add fake class dimension and remove it again
preds_2d = [p[:, None] for p in all_preds]
preds_per_trial = compute_preds_per_trial_from_crops(
preds_2d, self.input_time_length, dataset.X
)
preds_per_trial = [p[0] for p in preds_per_trial]
pred_timeseries = np.concatenate(preds_per_trial, axis=0)
ys_2d = [y[:, None] for y in all_targets]
targets_per_trial = compute_preds_per_trial_from_crops(
ys_2d, self.input_time_length, dataset.X
)
targets_per_trial = [t[0] for t in targets_per_trial]
target_timeseries = np.concatenate(targets_per_trial, axis=0)
corr = np.corrcoef(target_timeseries, pred_timeseries)[0, 1]
key = setname + "_corr"
return {key: float(corr)}
# %% visualization (Kay): Phase and Amplitude perturbation
import torch
import numpy as np
from braindecode.util import wrap_reshape_apply_fn, corr
from braindecode.datautil.iterators import get_balanced_batches
class SelectiveSequential(nn.Module):
def __init__(self, to_select, modules_list):
"""
Returns intermediate activations of a network during forward pass
to_select: list of module names for which activation should be returned
modules_list: Modules of the network in the form [[name1, mod1],[name2,mod2]...)
Important: modules_list has to include all modules of the network, not only those of interest
https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/8
"""
super(SelectiveSequential, self).__init__()
for key, module in modules_list:
self.add_module(key, module)
self._modules[key].load_state_dict(module.state_dict())
self._to_select = to_select
def forward(self, x):
# Call modules individually and append activation to output if module is in to_select
o = []
for name, module in self._modules.items():
x = module(x)
if name in self._to_select:
o.append(x)
return o
def phase_perturbation(amps, phases, rng=np.random.RandomState()):
"""
Takes amps and phases of BxCxF with B input, C channels, F frequencies
Shifts spectral phases randomly for input and frequencies, but same for all channels
amps: Spectral amplitude (not used)
phases: Spectral phases
rng: Random Seed
Output:
amps_pert: Input amps (not modified)
phases_pert: Shifted phases
pert_vals: Absolute phase shifts
"""
noise_shape = list(phases.shape)
noise_shape[1] = 1 # Do not sample noise for channels individually
# Sample phase perturbation noise
phase_noise = rng.uniform(-np.pi, np.pi, noise_shape).astype(np.float32)
phase_noise = phase_noise.repeat(phases.shape[1], axis=1)
# Apply noise to inputs
phases_pert = phases + phase_noise
phases_pert[phases_pert < -np.pi] += 2 * np.pi
phases_pert[phases_pert > np.pi] -= 2 * np.pi
return amps, phases_pert, np.abs(phase_noise)
def phase_perturbation_chnls(amps, phases, rng=np.random.RandomState()):
"""
Takes amps and phases of BxCxF with B input, C channels, F frequencies
Shifts spectral phases randomly for input and frequencies, but same for all channels
amps: Spectral amplitude (not used)
phases: Spectral phases
rng: Random Seed
Output:
amps_pert: Input amps (not modified)
phases_pert: Shifted phases
pert_vals: Absolute phase shifts
"""
noise_shape = list(phases.shape)
# noise_shape[1] = 1 # Do not sample noise for channels individually
# Sample phase perturbation noise
phase_noise = rng.uniform(-np.pi, np.pi, noise_shape).astype(np.float32)
# phase_noise = phase_noise.repeat(phases.shape[1],axis=1)
# Apply noise to inputs
phases_pert = phases + phase_noise
phases_pert[phases_pert < -np.pi] += 2 * np.pi
phases_pert[phases_pert > np.pi] -= 2 * np.pi
return amps, phases_pert, np.abs(phase_noise)
def amp_perturbation_additive(amps, phases, rng=np.random.RandomState()):
"""
Takes amps and phases of BxCxF with B input, C channels, F frequencies
Adds additive noise to amplitudes
amps: Spectral amplitude
phases: Spectral phases (not used)
rng: Random Seed
Output:
amps_pert: Scaled amplitudes
phases_pert: Input phases (not modified)
pert_vals: Amplitude noise
"""
amp_noise = rng.normal(0, 1, amps.shape).astype(np.float32)
amps_pert = amps + amp_noise
amps_pert[amps_pert < 0] = 0
return amps_pert, phases, amp_noise
def amp_perturbation_multiplicative(amps, phases, rng=np.random.RandomState()):
"""
Takes amps and phases of BxCxF with B input, C channels, F frequencies
Adds multiplicative noise to amplitudes
amps: Spectral amplitude
phases: Spectral phases (not used)
rng: Random Seed
Output:
amps_pert: Scaled amplitudes
phases_pert: Input phases (not modified)
pert_vals: Amplitude scaling factor
"""
amp_noise = rng.normal(1, 0.02, amps.shape).astype(np.float32)
amps_pert = amps * amp_noise
amps_pert[amps_pert < 0] = 0
return amps_pert, phases, amp_noise
def correlate_feature_maps(x, y):
"""
Takes two activation matrices of the form Bx[F]xT where B is batch size, F number of filters (optional) and T time points
Returns correlations of the corresponding activations over T
Input: Bx[F]xT (x,y)
Returns: Bx[F]
"""
shape_x = x.shape
shape_y = y.shape
assert np.array_equal(shape_x, shape_y)
assert len(shape_x) < 4
x = x.reshape((-1, shape_x[-1]))
y = y.reshape((-1, shape_y[-1]))
x = (x - x.mean(axis=1, keepdims=True)) / x.std(axis=1, keepdims=True)
y = (y - y.mean(axis=1, keepdims=True)) / y.std(axis=1, keepdims=True)
tmp_corr = x * y
corr_ = tmp_corr.sum(axis=1)
# corr_ = np.zeros((x.shape[0]))
# for i in range(x.shape[0]):
# # Correlation of standardized variables
# corr_[i] = np.correlate((x[i]-x[i].mean())/x[i].std(),(y[i]-y[i].mean())/y[i].std())
return corr_.reshape(*shape_x[:-1])
def mean_diff_feature_maps(x, y):
"""
Takes two activation matrices of the form BxFxT where B is batch size, F number of filters and T time points
Returns mean difference between feature map activations
Input: BxFxT (x,y)
Returns: BxF
"""
return np.mean(x - y, axis=2)
def perturbation_correlation(
pert_fn,
diff_fn,
pred_fn,
n_layers,
inputs,
n_iterations,
batch_size=30,
seed=((2017, 7, 10)),
):
"""
Calculates phase perturbation correlation for layers in network
pred_fn: Function that returns a list of activations.
Each entry in the list corresponds to the output of 1 layer in a network
n_layers: Number of layers pred_fn returns activations for.
inputs: Original inputs that are used for perturbation [B,X,T,1]
Phase perturbations are sampled for each input individually, but applied to all X of that input
n_iterations: Number of iterations of correlation computation. The higher the better
batch_size: Number of inputs that are used for one forward pass. (Concatenated for all inputs)
"""
rng = np.random.RandomState(seed)
# Get batch indeces
batch_inds = get_balanced_batches(
n_trials=len(inputs), rng=rng, shuffle=False, batch_size=batch_size
)
# Calculate layer activations and reshape
orig_preds = [pred_fn(inputs[inds]) for inds in batch_inds]
orig_preds_layers = [
np.concatenate([orig_preds[o][l] for o in range(len(orig_preds))])
for l in range(n_layers)
]
# Compute FFT of inputs
fft_input = np.fft.rfft(inputs, n=inputs.shape[2], axis=2)
amps = np.abs(fft_input)
phases = np.angle(fft_input)
pert_corrs = [0] * n_layers
for i in range(n_iterations):
print("Iteration%d" % i)
amps_pert, phases_pert, pert_vals = pert_fn(amps, phases, rng=rng)
# Compute perturbed inputs
fft_pert = amps_pert * np.exp(1j * phases_pert)
inputs_pert = np.fft.irfft(fft_pert, n=inputs.shape[2], axis=2).astype(
np.float32
)
# Calculate layer activations for perturbed inputs
new_preds = [pred_fn(inputs_pert[inds]) for inds in batch_inds]
new_preds_layers = [
np.concatenate([new_preds[o][l] for o in range(len(new_preds))])
for l in range(n_layers)
]
for l in range(n_layers):
# Calculate correlations of original and perturbed feature map activations
preds_diff = diff_fn(
orig_preds_layers[l][:, :, :, 0], new_preds_layers[l][:, :, :, 0]
)
# Calculate feature map correlations with absolute phase perturbations
pert_corrs_tmp = wrap_reshape_apply_fn(
corr, pert_vals[:, :, :, 0], preds_diff, axis_a=(0), axis_b=(0)
)
# pert_corrs[l] += pert_corrs_tmp # line commeted out by Jiri
# code added by Jiri
# X = pert_corrs_tmp[:, 0:130, :, np.newaxis] # 4-D array (ch,freq,units,1)
X_lofr = pert_corrs_tmp[:, 2:7, :, np.newaxis].mean(
axis=1, keepdims=True
) # low freq: 0.73 - 2.57 Hz
X_beta = pert_corrs_tmp[:, 36:82, :, np.newaxis].mean(
axis=1, keepdims=True
) # beta freq: 13.2 - 30.1 Hz
X = np.concatenate((X_lofr, X_beta), axis=1)
if i == 0:
pert_corrs[l] = X
else:
pert_corrs[l] = np.concatenate(
(pert_corrs[l], X), axis=3
) # 4-D array (ch,freq,units,iters)
# pert_corrs = [pert_corrs[l]/n_iterations for l in range(n_layers)] #mean over iterations, commented out by Jiri
return pert_corrs
# %% Loss function takes predictions as they come out of the network and the targets and returns a loss
loss_function = F.mse_loss
# Could be used to apply some constraint on the models, then should be object with apply method that accepts a module
model_constraint = None
# %% Monitors log the training progress
monitors = [
LossMonitor(),
CorrelationMonitor1d(input_time_length),
RuntimeMonitor(),
]
# %% Stop criterion determines when the first stop happens
stop_criterion = MaxEpochs(maxTrainEpochs)
# %% x-validation loop
if remoteComputing:
N = len(datasets)
inds = np.arange(N)
else:
N = 6
inds = np.arange(len(datasets))[-N:]
cc_folds = np.zeros(N)
pred_vals = []
resp_vals = []
# %% dataset indices
n = 0
i_test_set = inds[-1]
i_valid_set = inds[-2]
i_train_set = inds[
:-1
] # merges valid set with train set, previously: i_train_set = inds_new[:-2]
log.info("test set = %s" % i_test_set)
log.info("valid set = %s" % i_valid_set)
log.info("train set = %s" % i_train_set)
# %% datasets
train_set = concatenate_sets(
np.array(datasets)[i_train_set]
) # also: train_set = concatenate_sets([datasets[i] for i in i_train_set])
valid_set = datasets[i_valid_set]
test_set = datasets[i_test_set]
log.info("Train set has {:d} folds".format(len(train_set.X)))
# %% re-initialize model
model.load_state_dict(deepcopy(start_param_values))
optimizer = optim.Adam(model.parameters())
# %% DNN setup & run
# exp = Experiment(model, train_set, valid_set, test_set, iterator,
# loss_function, optimizer, model_constraint,
# monitors, stop_criterion,
# remember_best_column='train_loss',
# run_after_early_stop=False, batch_modifier=None, cuda=cuda, do_early_stop=False)
exp = Experiment(
model,
train_set,
valid_set,
test_set,
iterator,
loss_function,
optimizer,
model_constraint,
monitors,
stop_criterion,
remember_best_column="train_loss",
run_after_early_stop=False,
batch_modifier=None,
cuda=cuda,
)
exp.run()
# %% visualization (Kay): Wrap Model into SelectiveSequential and set up pred_fn
assert len(list(model.children())) == len(
list(model.named_children())
) # All modules gotta have names!
modules = list(model.named_children()) # Extract modules from model
## KAY added conv_classifier
select_modules = [
"conv_spat",
"conv_2",
"conv_3",
"conv_4",
"conv_classifier",
] # Specify intermediate outputs
model_pert = SelectiveSequential(select_modules, modules) # Wrap modules
# Prediction function that is used in phase_perturbation_correlation
model_pert.eval()
pred_fn = lambda x: [
layer_out.data.numpy()
for layer_out in model_pert.forward(
torch.autograd.Variable(torch.from_numpy(x)).float()
)
]
# Gotta change pred_fn a bit for cuda case
if cuda:
model_pert.cuda()
pred_fn = lambda x: [
layer_out.data.cpu().numpy()
for layer_out in model_pert.forward(
torch.autograd.Variable(torch.from_numpy(x)).float().cuda()
)
]
perm_X = np.expand_dims(train_set.X, 3) # Input gotta have dimension BxCxTx1
# %% save model for Robin
torch.save(model, dir_outputData + "/" + fileName + "_model")
# %% reshape the array to 2*n_preds_per_input samples (2-s) long window
wSize = 682 # 2*n_preds_per_input #smallest possible=685 (empirically found)
train_set_new = np.asarray(train_set.X)
nWindows = int(train_set_new.shape[2] / wSize)
train_set_new = train_set_new[:, :, : wSize * int(train_set_new.shape[2] / wSize)]
shape_tmp = train_set_new.shape
## CHECK ORDER OF reshape (compareReshaping.m)
# sio.savemat(dir_outputData + '/' + 'train_set_folds' + '.mat', {'train_set_new':train_set_new})
## EITHER wSize first or second after
train_set_new = train_set_new.reshape(shape_tmp[0], shape_tmp[1], nWindows, wSize)
train_set_new = train_set_new.transpose(0, 2, 1, 3)
train_set_new = train_set_new.reshape(
shape_tmp[0] * nWindows, shape_tmp[1], wSize, 1
)
# sio.savemat(dir_outputData + '/' + 'train_set_new' + '.mat', {'train_set_new':train_set_new})
# %% visualization (Kay): Run phase and amplitude perturbations
log.info("visualization: perturbation computation ...")
# phase_pert_corrs = perturbation_correlation(phase_perturbation_chnls, correlate_feature_maps, pred_fn,4,perm_X,N_perturbations,batch_size=2000)
# amp_pert_corrs = perturbation_correlation(amp_perturbation_additive, mean_diff_feature_maps, pred_fn,4,perm_X,N_perturbations,batch_size=2000)
phase_pert_corrs = perturbation_correlation(
phase_perturbation_chnls,
correlate_feature_maps,
pred_fn,
5,
train_set_new,
N_perturbations,
batch_size=200,
)
# phase_pert_mdiff = perturbation_correlation(phase_perturbation_chnls, mean_diff_feature_maps, pred_fn,5,train_set_new,N_perturbations,batch_size=200)
amp_pert_mdiff = perturbation_correlation(
amp_perturbation_additive,
mean_diff_feature_maps,
pred_fn,
5,
train_set_new,
N_perturbations,
batch_size=200,
)
# %% save perturbation over layers
# freqs = np.fft.rfftfreq(perm_X.shape[2],d=1/250.)
freqs = np.fft.rfftfreq(train_set_new.shape[2], d=1 / 250.0)
for l in range(len(phase_pert_corrs)):
layer_cc = phase_pert_corrs[l]
sio.savemat(
dir_outputData
+ "/"
+ fileName
+ "_phiPrtCC"
+ "_layer{:d}".format(l)
+ "_fold{:d}".format(n)
+ ".mat",
{"layer_cc": layer_cc},
)
# layer_cc = phase_pert_mdiff[l]
# sio.savemat(dir_outputData + '/' + fileName + '_phiPrtMD' + '_layer{:d}'.format(l) + '_fold{:d}'.format(n) + '.mat', {'layer_cc':layer_cc})
layer_cc = amp_pert_mdiff[l]
sio.savemat(
dir_outputData
+ "/"
+ fileName
+ "_ampPrtMD"
+ "_layer{:d}".format(l)
+ "_fold{:d}".format(n)
+ ".mat",
{"layer_cc": layer_cc},
)
# %% plot learning curves
if not remoteComputing:
f, axarr = plt.subplots(2, figsize=(15, 15))
exp.epochs_df.loc[:, ["train_loss", "valid_loss", "test_loss"]].plot(
ax=axarr[0], title="loss function"
)
exp.epochs_df.loc[:, ["train_corr", "valid_corr", "test_corr"]].plot(
ax=axarr[1], title="correlation"
)
plt.savefig(
dir_outputData + "/" + fileName + "_fig_lc_fold{:d}.png".format(n),
bbox_inches="tight",
)
# %% evaluation
all_preds = []
all_targets = []
dataset = test_set
for batch in exp.iterator.get_batches(dataset, shuffle=False):
preds, loss = exp.eval_on_batch(batch[0], batch[1])
all_preds.append(preds)
all_targets.append(batch[1])
preds_2d = [p[:, None] for p in all_preds]
preds_per_trial = compute_preds_per_trial_from_crops(
preds_2d, input_time_length, dataset.X
)[0][0]
ys_2d = [y[:, None] for y in all_targets]
targets_per_trial = compute_preds_per_trial_from_crops(
ys_2d, input_time_length, dataset.X
)[0][0]
assert preds_per_trial.shape == targets_per_trial.shape
# %% save values: CC, pred, resp
exp.epochs_df.to_csv(
dir_outputData + "/" + fileName + "_epochs_fold{:d}.csv".format(n),
sep=",",
header=False,
)
# exp.epochs_df.to_excel(dir_outputData + '/' + fileName + '_epochs_fold{:d}.xls'.format(n))
cc_folds[n] = np.corrcoef(preds_per_trial, targets_per_trial)[0, 1]
pred_vals.append(preds_per_trial)
resp_vals.append(targets_per_trial)
# %% plot predicted trajectory
if not remoteComputing:
plt.figure(figsize=(32, 12))
t = np.arange(preds_per_trial.shape[0]) / srates[n]
plt.plot(t, preds_per_trial)
plt.plot(t, targets_per_trial)
plt.legend(("Predicted", "Actual"), fontsize=14)
plt.title("Fold = {:d}, CC = {:f}".format(n, cc_folds[n]))
plt.xlabel("time [s]")
plt.savefig(
dir_outputData + "/" + fileName + "_fig_predResp_fold{:d}.png".format(n),
bbox_inches="tight",
)
log.info("x-validation loop = " + str(n) + " done!")
log.info("-----------------------------------------")
# %% save CCs
# np.save(dir_outputData + '/' + fileName + '_cc_folds', cc_folds)
sio.savemat(
dir_outputData + "/" + fileName + "_cc_folds" + ".mat", {"cc_folds": cc_folds}
)
# np.save(dir_outputData + '/' + fileName + '_pred_vals', pred_vals)
sio.savemat(
dir_outputData + "/" + fileName + "_pred_vals" + ".mat",
{"pred_vals": pred_vals},
)
# np.save(dir_outputData + '/' + fileName + '_resp_vals', resp_vals)
sio.savemat(
dir_outputData + "/" + fileName + "_resp_vals" + ".mat",
{"resp_vals": resp_vals},
)
sio.savemat(
dir_outputData + "/" + fileName + "_freqs" + ".mat", {"freqs": freqs[0:130]}
)
log.info("job: " + fileName + " done!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment