@anshajk
Created June 20, 2020 11:25
A small function for creating a windowed dataset for sequential networks using TensorFlow 2.x
import tensorflow as tf


def windowed_dataset(series, window_size, batch_size, shuffle_buffer):
    """Create a windowed dataset of (inputs, label) pairs for sequence training."""
    dataset = tf.data.Dataset.from_tensor_slices(series)
    # Slide a window of window_size + 1 values: window_size inputs plus one label.
    dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)
    # Each window is itself a Dataset; flatten it into a single tensor.
    dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
    # Shuffle the windows, then split each into (inputs, label).
    dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1], window[-1]))
    dataset = dataset.batch(batch_size).prefetch(1)
    return dataset
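
To make the windowing concrete, here is a plain-Python sketch of what the window, flat_map and map steps produce, with no TensorFlow involved. The helper name `sliding_windows` is hypothetical and not part of the gist; it deliberately leaves out shuffling, batching and prefetching, so it only shows how each window of `window_size + 1` values is split into inputs and a label.

```python
def sliding_windows(series, window_size):
    """Return (inputs, label) pairs the way the windowing steps build them.

    Hypothetical helper for illustration only; it skips shuffling,
    batching and prefetching.
    """
    span = window_size + 1  # window_size inputs plus one label
    pairs = []
    # Stop early so every window is full, like drop_remainder=True.
    for start in range(len(series) - span + 1):
        window = list(series[start:start + span])
        # All but the last value are inputs; the last value is the label.
        pairs.append((window[:-1], window[-1]))
    return pairs


pairs = sliding_windows(list(range(6)), window_size=3)
for inputs, label in pairs:
    print(inputs, "->", label)
# [0, 1, 2] -> 3
# [1, 2, 3] -> 4
# [2, 3, 4] -> 5
```

A series of length n yields n - window_size pairs, and a series shorter than `window_size + 1` yields none, which is why very short inputs produce an empty dataset.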