Last active
February 7, 2018 06:42
-
-
Save jperl/954631259eda0f81be750b67e25f9bc4 to your computer and use it in GitHub Desktop.
Stack past tensor slices.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
def _stack_past(x, steps): | |
"""Stack the past data for each step. | |
Ex. x = [0, ..., 60]. steps = [10, 20] | |
Result [x, x[:-10], x[:-20]] normalized to the same shape | |
""" | |
# Sort the steps in ascending order [10, 20] | |
sorted_steps = steps.copy() | |
sorted_steps.sort() | |
largest_step = sorted_steps[-1] | |
# Include the original data | |
stacks = [x[largest_step:]] | |
for step in sorted_steps: | |
# Normalize the shapes by skipping the difference from the largest step | |
# Which are rows without enough past data | |
skip = largest_step - step | |
stacks.append(x[skip:-step]) | |
return stacks | |
def np_stack_past(x, steps):
    """Stack ``x`` and its past-shifted slices into a single numpy array.

    Args:
        x: Array-like sliced along its first axis.
        steps: Offsets forwarded unchanged to ``_stack_past``.

    Returns:
        ``numpy.ndarray`` with a new last axis holding, in order, the
        unshifted slice followed by each shifted slice from ``_stack_past``.
    """
    return np.stack(_stack_past(x, steps), axis=-1)
def tf_stack_past(tensor, steps):
    """Stack ``tensor`` and its past-shifted slices into a single tensor.

    Args:
        tensor: ``tf.Tensor`` sliced along its first axis.
        steps: Offsets forwarded unchanged to ``_stack_past``.

    Returns:
        ``tf.Tensor`` with a new last axis holding, in order, the unshifted
        slice followed by each shifted slice from ``_stack_past``.
    """
    return tf.stack(_stack_past(tensor, steps), axis=-1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import numpy as np | |
from numpy.testing import assert_array_equal | |
logging.getLogger('tensorflow').disabled = True | |
import tensorflow as tf # noqa | |
from utils.transform.stack import np_stack_past, tf_stack_past # noqa | |
def example_matrix():
    """Return a 100 x 3 array of string labels used by the stacking tests.

    Row ``i`` is ``[str(i) + 'a', str(i) + 'b', str(i) + 'c']``, giving
    ``[['0a', '0b', '0c'], ..., ['99a', '99b', '99c']]``.
    """
    rows = [[str(row) + suffix for suffix in ('a', 'b', 'c')]
            for row in range(100)]
    return np.stack(rows)
class StackTestCase(tf.test.TestCase):
    """Checks np_stack_past and tf_stack_past against hand-written rows."""

    def test_stack_past(self):
        x = example_matrix()
        # Shared by the numpy and tensorflow checks so they stay in sync.
        steps = [1, 2, 10, 20]
        past = np_stack_past(x, steps)

        # 100 rows minus the 20 rows lacking history; 3 columns; 5 offsets
        # (the unshifted slice plus one per step).
        self.assertEqual(past.shape, (80, 3, 5))

        # The first 20 rows (0..19) lack 20 rows of history, so the first
        # stacked row corresponds to row 20 of x.
        #                    0      -1     -2     -10    -20
        expected_first = [['20a', '19a', '18a', '10a', '0a'],
                          ['20b', '19b', '18b', '10b', '0b'],
                          ['20c', '19c', '18c', '10c', '0c']]
        assert_array_equal(past[0], expected_first)

        # The last stacked row corresponds to row 99 of x.
        #                   0      -1     -2     -10    -20
        expected_last = [['99a', '98a', '97a', '89a', '79a'],
                         ['99b', '98b', '97b', '89b', '79b'],
                         ['99c', '98c', '97c', '89c', '79c']]
        assert_array_equal(past[-1], expected_last)

        with self.test_session():
            x_tensor = tf.constant(x, dtype=tf.string)
            past_tensor = tf_stack_past(x_tensor, steps)
            # Convert the evaluated tf.string values to a unicode array so
            # they compare equal to the str expectations above.
            result = past_tensor.eval().astype('U13')
            self.assertAllEqual(result[0], expected_first)
            self.assertAllEqual(result[-1], expected_last)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.