Created July 14, 2017 07:20
hackintoshrao/4e23fdd383808228a9f70c8173545d5b
Convert RNN training data into batches
```python
import numpy as np


def get_batches(arr, n_seqs, n_steps):
    '''Create a generator that returns batches of size
       n_seqs x n_steps from arr.

       Arguments
       ---------
       arr: Array you want to make batches from
       n_seqs: Batch size, the number of sequences per batch
       n_steps: Number of sequence steps per batch
    '''
    # Get the number of characters per batch and number of batches we can make
    characters_per_batch = n_seqs * n_steps
    n_batches = len(arr) // characters_per_batch

    # Keep only enough characters to make full batches
    arr = arr[:n_batches * characters_per_batch]

    # Reshape into n_seqs rows; each row is one contiguous stream of characters
    arr = arr.reshape((n_seqs, -1))

    for n in range(0, arr.shape[1], n_steps):
        # The features
        x = arr[:, n:n+n_steps]
        # The targets, shifted by one: each target is the character that
        # follows the corresponding input character
        y = np.zeros_like(x)
        y[:, :-1] = x[:, 1:]
        # The last target is the first character of the next window; the
        # final window wraps around to the first column of arr
        next_col = n + n_steps
        y[:, -1] = arr[:, next_col] if next_col < arr.shape[1] else arr[:, 0]
        yield x, y
```
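The generator above expects `arr` to be an integer array of encoded characters and yields one window at a time. The sketch below is an illustration under stated assumptions, not the author's code. It shows one way to produce that encoded array, plus an alternative design that materializes every batch at once as arrays of shape `(n_batches, n_seqs, n_steps)`. The helpers `encode_text` and `batches_as_array` are hypothetical names. The encoding is assumed to be a simple sorted character vocabulary. The targets come from `np.roll`, so the last column of each row wraps around to that row's first column.

```python
import numpy as np


def encode_text(text):
    """Hypothetical helper: map each character to an integer index."""
    vocab = sorted(set(text))
    vocab_to_int = {ch: i for i, ch in enumerate(vocab)}
    encoded = np.array([vocab_to_int[ch] for ch in text], dtype=np.int32)
    return encoded, vocab


def batches_as_array(arr, n_seqs, n_steps):
    """Hypothetical helper: return all (x, y) batches stacked along axis 0."""
    characters_per_batch = n_seqs * n_steps
    n_batches = len(arr) // characters_per_batch
    if n_batches == 0:
        raise ValueError("arr is too short for a single n_seqs x n_steps batch")

    # Trim to full batches and lay the data out as n_seqs contiguous rows
    arr = arr[:n_batches * characters_per_batch].reshape((n_seqs, -1))

    # Targets: each row shifted left by one; the last column wraps around
    # to the first column of the same row
    targets = np.roll(arr, -1, axis=1)

    # Cut the columns into n_batches windows of n_steps and stack them
    x = np.stack(np.split(arr, n_batches, axis=1))
    y = np.stack(np.split(targets, n_batches, axis=1))
    return x, y


if __name__ == "__main__":
    encoded, vocab = encode_text("hello world, hello rnn!")
    x, y = batches_as_array(encoded, n_seqs=2, n_steps=3)
    print(x.shape)  # (3, 2, 3)
```

Precomputing everything trades memory for simplicity. This is convenient for small corpora or for shuffling whole batches. The generator form keeps only one window in memory at a time.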