@jamescalam
Created May 4, 2020 14:56
Creating the dataset object ready for training in TensorFlow.
import tensorflow as tf

# define the input/target data splitting function
def split_xy(seq):
    input_data = seq[:-1]  # every character except the last
    target_data = seq[1:]  # every character except the first
    return input_data, target_data

SEQLEN = 100  # the number of characters in a single sequence
BATCHSIZE = 64  # how many sequences in a single training batch
BUFFER = 10000  # how many elements are contained within a single shuffling space

# create the training dataset with the tf.data Dataset API
dataset = tf.data.Dataset.from_tensor_slices(data_idx)

# the batch method converts individual characters into sequences of the desired size
sequences = dataset.batch(SEQLEN + 1, drop_remainder=True)

# map each sequence to an (input, target) pair offset by one character
dataset = sequences.map(split_xy)

# shuffle the dataset AND batch into batches of 64 sequences
dataset = dataset.shuffle(BUFFER).batch(BATCHSIZE, drop_remainder=True)
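To see why each sequence is built with `SEQLEN + 1` characters, here is a minimal sketch of what `split_xy` does, run on a toy five-character "sequence" (plain Python lists stand in for the TensorFlow tensors the real pipeline uses):

```python
def split_xy(seq):
    input_data = seq[:-1]   # every character except the last
    target_data = seq[1:]   # every character except the first
    return input_data, target_data

seq = list("hello")
x, y = split_xy(seq)
print(x)  # ['h', 'e', 'l', 'l']
print(y)  # ['e', 'l', 'l', 'o']
```

The target is the input shifted one step forward, so the model learns to predict the next character at every position. One extra character per sequence is needed so that both the input and the target still contain `SEQLEN` characters; with `drop_remainder=True`, any characters left over after slicing the corpus into chunks of `SEQLEN + 1` are discarded.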