@ntakouris
Created September 25, 2020 14:33
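import tensorflow as tf
from functools import partial

WINDOW = 2  # rows per window; window() defaults shift to the window size, so windows do not overlap

def collate_pair(x, window):
    # x is one window of shape (window, 2) whose columns are [a, b].
    # Returns (every `a` value in the window, the last `b` value as a length-1 tensor).
    # Note: the `window` argument is currently unused.
    return x[:, 0], x[-1:, 1]

# 100 rows of [i, 1 if i is a multiple of 3 else 0]
ds_raw = tf.constant([[i, 1 if i % 3 == 0 else 0] for i in range(100)])
ds = tf.data.Dataset.from_tensor_slices(ds_raw)
# window() yields sub-datasets; batch(WINDOW) turns each one into a single (WINDOW, 2) tensor.
ds = ds.window(WINDOW).flat_map(lambda w: w.batch(WINDOW)).map(partial(collate_pair, window=WINDOW))

for a, b in ds.take(1):
    print(f'a: {a} -> b: {b}')
# a: [0 1] -> b: [0]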
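The pipeline above is hard to check by eye, so here is a plain-Python sketch (no TensorFlow required) that mirrors what `ds.window(size, shift, drop_remainder=True).flat_map(lambda w: w.batch(size)).map(collate_pair)` should produce on the same 100 rows. The helpers `make_rows` and `window_pairs` and their `shift` parameter are hypothetical names introduced only for this illustration. The sketch assumes partial windows are dropped; for the gist's own settings that makes no difference, because 100 rows split evenly into windows of 2.

```python
from typing import List, Optional, Sequence, Tuple

Row = Tuple[int, int]
Pair = Tuple[List[int], List[int]]


def make_rows(n: int) -> List[Row]:
    """Hypothetical helper: rows (i, 1 if i % 3 == 0 else 0), the same data as ds_raw."""
    return [(i, 1 if i % 3 == 0 else 0) for i in range(n)]


def window_pairs(rows: Sequence[Row], size: int, shift: Optional[int] = None) -> List[Pair]:
    """Hypothetical plain-Python mirror of window/flat_map(batch)/map(collate_pair).

    Each output pair is (column a of every row in the window, [column b of the last row]).
    Windows shorter than `size` are skipped, like window(..., drop_remainder=True).
    """
    if shift is None:
        shift = size  # tf.data's window() also defaults shift to size
    out: List[Pair] = []
    for start in range(0, len(rows) - size + 1, shift):
        window = rows[start:start + size]
        a = [row[0] for row in window]
        b = [window[-1][1]]
        out.append((a, b))
    return out


rows = make_rows(100)
pairs = window_pairs(rows, size=2)             # non-overlapping, as in the gist
sliding = window_pairs(rows, size=2, shift=1)  # overlapping windows
print(pairs[0])                 # ([0, 1], [0])
print(len(pairs), len(sliding))  # 50 99
```

In `tf.data` itself, overlapping windows come from passing `shift=1` to `Dataset.window`, together with `drop_remainder=True` so the final one-row window is not emitted. The sketch's `shift=1` case models that configuration and yields 99 pairs instead of 50.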