Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Created September 25, 2020 14:44
Show Gist options
  • Save ntakouris/61ff1df369a95c457ebd2e078fa07eac to your computer and use it in GitHub Desktop.
Save ntakouris/61ff1df369a95c457ebd2e078fa07eac to your computer and use it in GitHub Desktop.
import tensorflow as tf
from functools import partial
def collate_pair(x):
    """Split a window into an (inputs, target) pair.

    The last element becomes the target and everything before it becomes
    the input sequence. Intended for use with ``Dataset.map`` on windows
    that were batched into a single tensor, but it works on any sliceable
    sequence.

    Args:
        x: A non-empty sliceable sequence (e.g. a 1-D tensor or a list).

    Returns:
        A tuple ``(x[:-1], x[-1])``. For a length-1 input the inputs part
        is empty.
    """
    return x[:-1], x[-1]
# Toy input: the integers 0..99 as int32 scalars.
ds_raw = tf.range(100)
ds = tf.data.Dataset.from_tensor_slices(ds_raw)

WINDOW_SIZE = 3
# Dataset.window(size) defaults to shift=size, so the windows do NOT overlap:
# [0 1 2], [3 4 5], ... Each window is itself a Dataset; batching it by
# WINDOW_SIZE inside flat_map turns it back into one dense tensor.
# drop_remainder=True discards the trailing short window ([99]), which would
# otherwise become the degenerate sample ([], 99) after collate_pair.
ds = (
    ds.window(WINDOW_SIZE, drop_remainder=True)
    .flat_map(lambda w: w.batch(WINDOW_SIZE))
    .map(collate_pair)
)

for a, b in ds.take(2):
    print(f'a: {a} -> b: {b}')
# a: [0 1] -> b: 2
# a: [3 4] -> b: 5
# DATA LOSS! Because shift defaults to the window size, the overlapping
# samples [1 2] -> 3, [2 3] -> 4, ... are never produced, and values such as
# 2 and 5 only ever appear as targets, never as inputs. Passing shift=1,
# i.e. ds.window(WINDOW_SIZE, shift=1, drop_remainder=True), yields every
# sliding window instead: [0 1] -> 2, [1 2] -> 3, [2 3] -> 4, ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment