Last active
December 9, 2019 08:26
-
-
Save dsalaj/dc18edc82053df35087f2ab3026f61e7 to your computer and use it in GitHub Desktop.
Crossing Threshold Encoding of pixel values to spikes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def find_onset_offset(y, threshold):
    """
    Encode the 1-D numpy array `y` as spikes relative to the level `threshold`.

    Special case, `threshold == 1`: a single row of shape (1, len(y)) is
    returned, holding a 1 at every sample of `y` that equals the threshold
    (every such sample, not only the transitions).

    Otherwise two rows are returned, shape (2, len(y)):
      row 0 (onset)  -- 1 at index i where y[i] < threshold <= y[i + 1]
      row 1 (offset) -- 1 at index i where y[i + 1] < threshold <= y[i]
    The spike is placed on the sample just before the crossing.

    The returned array has the same dtype as `y`.
    """
    if threshold == 1:
        # Single "touch" channel: flag every sample lying exactly on the threshold.
        touch = np.zeros_like(y)
        touch[np.nonzero(y == threshold)[0]] = 1
        return touch[np.newaxis]

    below = y < threshold
    at_or_above = y >= threshold
    # Pair each sample with its successor; a crossing is reported at the earlier index.
    rising_idx = np.nonzero(below[:-1] & at_or_above[1:])[0]
    falling_idx = np.nonzero(at_or_above[:-1] & below[1:])[0]

    spikes = np.zeros((2,) + y.shape, dtype=y.dtype)
    spikes[0, rising_idx] = 1
    spikes[1, falling_idx] = 1
    return spikes
def get_data_dict(batch_size, type='train'):
    """
    Build the feed dictionary for running a TensorFlow op on one MNIST batch.

    :param batch_size: number of images to draw from the selected split.
    :param type: data split to draw from: 'train', 'validation' or 'test'.
        The test split is read without shuffling. (The name shadows the
        builtin ``type``; it is kept for compatibility with callers.)
    :return: tuple ``(data_dict, input_px)``. ``data_dict`` maps the
        ``input_spikes`` placeholder to the encoded batch and ``targets`` to
        the integer class labels. ``input_px`` is the pixel batch, already
        repeated along axis 1 when ``FLAGS.n_repeat > 1``.
    :raises ValueError: if ``type`` is not one of the known splits.
    """
    if type == 'test':
        input_px, target_oh = mnist.test.next_batch(batch_size, shuffle=False)
    elif type == 'validation':
        input_px, target_oh = mnist.validation.next_batch(batch_size)
    elif type == 'train':
        input_px, target_oh = mnist.train.next_batch(batch_size)
    else:
        raise ValueError("Wrong data group: " + str(type))
    # Convert one-hot targets to class indices.
    target_num = np.argmax(target_oh, axis=1)

    if FLAGS.n_repeat > 1:
        # Each pixel is presented for n_repeat consecutive steps.
        input_px = np.repeat(input_px, FLAGS.n_repeat, axis=1)

    if FLAGS.crs_thr:
        # Number of input neurons determines the threshold resolution.
        spike_stack = _threshold_crossing_spikes(input_px, FLAGS.n_in // 2)
        # Add an output-cue neuron and expand time by two image rows (2 * 28,
        # scaled by n_repeat). The cue neuron fires constantly during these
        # additional recall steps.
        out_cue_duration = 2 * 28 * FLAGS.n_repeat
        spike_stack = np.pad(
            spike_stack, ((0, 0), (0, out_cue_duration), (0, 1)), 'constant')
        spike_stack[:, -out_cue_duration:, -1] = 1
    else:
        # Feed raw pixel values as a single input channel.
        spike_stack = np.expand_dims(input_px, axis=2)

    data_dict = {input_spikes: spike_stack, targets: target_num}
    return data_dict, input_px


def _threshold_crossing_spikes(input_px, n_thresholds):
    """
    Encode every image in ``input_px`` as threshold-crossing spike trains.

    Thresholds are ``n_thresholds`` evenly spaced values in [0, 1]. For each
    image, the rows returned by ``find_onset_offset`` for every threshold are
    concatenated along the channel axis, then transposed so time comes first.

    :param input_px: batch of 1-D pixel sequences, iterated image by image.
    :param n_thresholds: number of evenly spaced thresholds in [0, 1].
    :return: array of shape (batch, time, channels).
    """
    thresholds = np.linspace(0, 1, n_thresholds)
    encoded = []
    for img in input_px:
        # Concatenate once per image instead of growing the array per threshold.
        channels = np.concatenate(
            [find_onset_offset(img, thr) for thr in thresholds])
        encoded.append(channels.T)
    return np.array(encoded)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment