# GitHub Gist by @ntakouris, created September 25, 2020.
# Gist ID: 876eb4d1b4be4bfe26143aeded1f770e
import tensorflow as tf
from functools import partial
# Pair fixed-size, non-overlapping windows of a sequence with the element that
# immediately follows each window, using tf.data (e.g. to build
# (input window, next-value target) pairs).
WINDOW_SIZE = 2

ds_raw = tf.range(100)  # int32 values 0..99, same as tf.constant(list(range(100)))
ds = tf.data.Dataset.from_tensor_slices(ds_raw)

# a: non-overlapping windows (shift defaults to size) -> [0 1], [2 3], ...
# drop_remainder=True guarantees every batch has exactly WINDOW_SIZE elements.
ds_a = ds.window(WINDOW_SIZE, drop_remainder=True).flat_map(
    lambda w: w.batch(WINDOW_SIZE)
)

# b: the same windowing shifted by one element. Keeping only the last element
# of each shifted window yields the element right after the matching window in
# ds_a. drop_remainder=True drops the trailing partial window ([99]); without
# it, [98 99] would be paired with the bogus target 99.
ds_b = (
    ds.skip(1)
    .window(WINDOW_SIZE, drop_remainder=True)
    .flat_map(lambda w: w.batch(WINDOW_SIZE))
    .map(lambda w: w[-1])
)

# Dataset.zip stops at the shorter input (ds_b), so the final window of ds_a,
# which has no following element, is not emitted.
for a, b in tf.data.Dataset.zip((ds_a, ds_b)):
    print(f'a: {a} -> b: {b}')
# a: [0 1] -> b: 2
# a: [2 3] -> b: 4
# ...
# a: [96 97] -> b: 98
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment