Created
August 7, 2013 20:08
def gen_batches(xs, size, overlap=0):
    """Given a sequence xs and a batch size, yield batches from the sequence as
    lists of length size, where the last batch might be smaller than the
    rest.

    An optional overlap amount may be given as a float in [0, 1) specifying a
    fraction of the batch size or as an int specifying the number of items to
    overlap.
    """
    assert size > 0
    if isinstance(overlap, float):
        assert 0 <= overlap < 1
        offset = int(size * overlap)
    else:
        assert 0 <= overlap < size
        offset = overlap
    acc = []
    for i, x in enumerate(xs):
        # A full batch is ready: emit it, then seed the next batch with the
        # trailing `offset` items so consecutive batches overlap.
        if i and len(acc) % size == 0:
            yield acc
            acc = acc[-offset:] if offset else []
        acc.append(x)
    # Emit the final batch, which may be shorter than size.
    if acc:
        yield acc
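
For inputs that support `len()` and slicing, the same overlapping batches can be sketched by stepping through the sequence with a stride of `size - offset`. The name `batches_by_slicing` is hypothetical and not part of the gist. Unlike `gen_batches`, this variant does not work on arbitrary iterables, and it raises `ValueError` where the original uses `assert`.

```python
def batches_by_slicing(xs, size, overlap=0):
    """Hypothetical slicing-based variant; xs must support len() and slicing."""
    if size <= 0:
        raise ValueError("size must be positive")
    # Same overlap convention as gen_batches: float = fraction, int = item count.
    offset = int(size * overlap) if isinstance(overlap, float) else overlap
    if not 0 <= offset < size:
        raise ValueError("overlap must be smaller than the batch size")
    stride = size - offset  # how far each batch start advances
    start = 0
    while start < len(xs):
        yield list(xs[start:start + size])
        if start + size >= len(xs):
            break  # this batch reached the end of the sequence
        start += stride


print(list(batches_by_slicing(list(range(10)), 4, overlap=1)))
# [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]
print(list(batches_by_slicing(list(range(5)), 4, overlap=0.5)))
# [[0, 1, 2, 3], [2, 3, 4]]
```

With `overlap=1`, each batch starts with the last item of the previous one. With `overlap=0.5` and `size=4`, the offset is `int(4 * 0.5) = 2`, so the final, shorter batch repeats the last two items of the first batch.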