Created
May 4, 2020 14:56
Creating the dataset object ready for training in TensorFlow.
```python
import tensorflow as tf

# define the input/target data splitting function
def split_xy(seq):
    input_data = seq[:-1]
    target_data = seq[1:]
    return input_data, target_data

SEQLEN = 100  # the number of characters in a single sequence
BATCHSIZE = 64  # how many sequences in a single training batch
BUFFER = 10000  # how many elements are contained within a single shuffling space

# create the training dataset with the tf.data Dataset API
dataset = tf.data.Dataset.from_tensor_slices(data_idx)
# the batch method converts individual characters into sequences of the desired size
sequences = dataset.batch(SEQLEN + 1, drop_remainder=True)
dataset = sequences.map(split_xy)  # map each sequence to an (input, target) pair
# shuffle the dataset AND batch into batches of 64 sequences
dataset = dataset.shuffle(BUFFER).batch(BATCHSIZE, drop_remainder=True)
```
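To see the pipeline end to end, here is a minimal self-contained sketch. It substitutes a random integer array for `data_idx` (an assumption; in the gist, `data_idx` is the integer-encoded text corpus built earlier) and shrinks the batch size so the toy corpus still yields full batches:

```python
import numpy as np
import tensorflow as tf

# stand-in for data_idx: 1,000 integer-encoded characters (assumption)
data_idx = np.random.randint(0, 65, size=1000)

SEQLEN = 100
BATCHSIZE = 4   # smaller than the gist's 64 so the toy corpus fills a batch
BUFFER = 100

def split_xy(seq):
    return seq[:-1], seq[1:]

dataset = tf.data.Dataset.from_tensor_slices(data_idx)
sequences = dataset.batch(SEQLEN + 1, drop_remainder=True)

# before shuffling, confirm the target is the input shifted by one character
inp, tgt = next(iter(sequences.map(split_xy)))
assert (inp[1:].numpy() == tgt[:-1].numpy()).all()

dataset = sequences.map(split_xy)
dataset = dataset.shuffle(BUFFER).batch(BATCHSIZE, drop_remainder=True)

x, y = next(iter(dataset))
print(x.shape, y.shape)  # each batch holds BATCHSIZE sequences of SEQLEN steps
```

Batching into sequences of `SEQLEN + 1` is what makes the one-step shift possible: each 101-character window splits into a 100-character input and a 100-character target offset by one.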