Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Created September 25, 2020 15:20
Show Gist options
  • Save ntakouris/ec791fce950576d8323e2853ed5e5e95 to your computer and use it in GitHub Desktop.
Save ntakouris/ec791fce950576d8323e2853ed5e5e95 to your computer and use it in GitHub Desktop.
def get_ds(from_i, to_i):
    """Return a flat dataset of consecutive integers, three per index.

    Every index ``k`` in ``range(from_i, to_i)`` is expanded into the chunk
    ``3k, 3k + 1, 3k + 2``, and the chunks are emitted in index order. The
    result is therefore ``3 * from_i .. 3 * to_i - 1`` when ``from_i < to_i``
    and empty otherwise.
    """
    def expand_chunk(chunk_idx):
        # One chunk: the three consecutive values starting at 3 * chunk_idx.
        return tf.data.Dataset.range(chunk_idx * 3, (chunk_idx + 1) * 3)

    chunk_indices = tf.data.Dataset.range(from_i, to_i)
    return chunk_indices.flat_map(expand_chunk)
def get_windowed_ds(i):
    """Build ``(window, next value)`` pairs for chunk ``i`` of the ``get_ds`` stream.

    Chunk ``i`` (``get_ds(i, i + 1)``, i.e. ``3i .. 3i + 2``) is concatenated
    with chunk ``i + 1`` (``3i + 3 .. 3i + 5``) so that windows may run
    across the chunk boundary. Windows of two consecutive values start at
    every other position, and each is paired with the value that immediately
    follows it. A window with no following value is dropped, so each chunk
    yields exactly two pairs (windows starting at ``3i`` and ``3i + 2``).

    Args:
        i: index of the chunk to window. When called from
            ``Dataset.interleave`` this is a scalar tensor.

    Returns:
        A ``tf.data.Dataset`` of ``(a, b)`` tuples, where ``a`` holds two
        consecutive values and ``b`` is the value right after ``a[-1]``.
    """
    ds_concat = get_ds(i, i + 1).concatenate(get_ds(i + 1, i + 2))
    # Take 3-value windows with shift 2: the first two values are the input
    # window and the third is its label. drop_remainder=True discards the
    # short tail window. Keeping it would pair a window with its own last
    # element (e.g. `a: [4 5] -> b: 5`) instead of a real successor.
    triples = ds_concat.window(3, shift=2, drop_remainder=True).flat_map(
        lambda w: w.batch(3)
    )
    return triples.map(lambda x: (x[:2], x[2]))


# NOTE(review): `tf` is never imported in this file; it needs
# `import tensorflow as tf` at the top before this script can run.
# cycle_length=2 / block_length=1 pin the interleave order recorded below.
# If cycle_length is unset, the tf.data runtime chooses it, so the printed
# order could differ between machines.
ds_final = tf.data.Dataset.range(5).interleave(
    get_windowed_ds, cycle_length=2, block_length=1
)

for a, b in ds_final:
    print(f'a: {a} -> b: {b}')

# Expected output:
# a: [0 1] -> b: 2
# a: [3 4] -> b: 5
# a: [2 3] -> b: 4
# a: [5 6] -> b: 7
# a: [6 7] -> b: 8
# a: [ 9 10] -> b: 11
# a: [8 9] -> b: 10
# a: [11 12] -> b: 13
# a: [12 13] -> b: 14
# a: [14 15] -> b: 16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment