Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Created September 25, 2020 15:16
Show Gist options
  • Save ntakouris/41f6c076fe7d92b50aefc390ddf13a92 to your computer and use it in GitHub Desktop.
Save ntakouris/41f6c076fe7d92b50aefc390ddf13a92 to your computer and use it in GitHub Desktop.
import tensorflow as tf
from functools import partial
def get_ds(from_i, to_i, group_size=3):
    """Build a dataset that expands each index into a run of consecutive integers.

    Every index ``x`` in ``[from_i, to_i)`` is replaced by the integers
    ``x * group_size`` up to (but excluding) ``(x + 1) * group_size``, and the
    runs are concatenated in index order.

    Args:
        from_i: First index (inclusive) passed to ``tf.data.Dataset.range``.
        to_i: End index (exclusive) passed to ``tf.data.Dataset.range``.
        group_size: Number of consecutive integers produced per index.
            Defaults to 3, the value that was previously hard-coded.

    Returns:
        The flattened ``tf.data.Dataset`` of integers.
    """
    indices = tf.data.Dataset.range(from_i, to_i)
    # flat_map turns each index into its own small range dataset and then
    # concatenates those ranges into one flat stream of elements.
    return indices.flat_map(
        lambda x: tf.data.Dataset.range(x * group_size, (x + 1) * group_size)
    )
# One dataset per index: get_ds(i, i + 1) covers only index i, so the
# datasets are disjoint and together span indices 0 through 4.
datasets = [get_ds(i, i + 1) for i in range(5)]

# Print each dataset's position followed by every element it yields.
for idx, ds in enumerate(datasets):
    print(f'idx {idx}')
    for sample in ds:
        print(f' {sample}')
"""
idx 0
0
1
2
idx 1
3
4
5
idx 2
6
7
8
idx 3
9
10
11
idx 4
12
13
14
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment