Created
February 27, 2020 16:14
-
-
Save andrewm4894/0a1210c4efee5d7e3b4a80607b2de8c2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def calc_batches(train_max: int, train_every: int, n: int) -> dict:
    """Compute the start/end row positions of successive training batches.

    Batch ``k`` (counting from 0) spans from ``train_every * k`` to
    ``train_max + train_every * k``. The first start is raised to 1 and
    every end is capped at ``n``. Batches therefore overlap whenever
    ``train_every`` is smaller than ``train_max``.

    Args:
        train_max: Maximum number of rows covered by a single batch.
        train_every: Number of rows the window advances between batches.
        n: Total number of rows available.

    Returns:
        A dict keyed by 1-based batch number whose values are
        ``{"start": ..., "end": ...}`` dicts. Empty when ``n`` is 0 or less.

    Note:
        The loop runs at most ``n`` times. If the batch end never lands on
        ``n`` (for example ``train_every <= 0`` with ``train_max < n``),
        ``n`` batches are returned without ever reaching the last row.
    """
    batches = {}
    # loop over up to as many records as you have
    for batch in range(n):
        offset = train_every * batch
        # work out the start of the batch, with a max() to handle first batch
        start = max(offset, 1)
        # work out the end of the batch, with a min() to handle last batch
        end = min(train_max + offset, n)
        # add batch info to the dictionary (batch numbers are 1-based)
        batches[batch + 1] = {"start": start, "end": end}
        # break out once you have assigned all rows to a batch
        if end == n:
            break
    return batches
# Example: 10,000 rows, batches of up to 1,000 rows, advancing 500 rows each
# time. The string below records the expected result for reference.
calc_batches(train_max=1000, train_every=500, n=10000)
'''
{1: {'start': 1, 'end': 1000},
2: {'start': 500, 'end': 1500},
3: {'start': 1000, 'end': 2000},
4: {'start': 1500, 'end': 2500},
5: {'start': 2000, 'end': 3000},
6: {'start': 2500, 'end': 3500},
7: {'start': 3000, 'end': 4000},
8: {'start': 3500, 'end': 4500},
9: {'start': 4000, 'end': 5000},
10: {'start': 4500, 'end': 5500},
11: {'start': 5000, 'end': 6000},
12: {'start': 5500, 'end': 6500},
13: {'start': 6000, 'end': 7000},
14: {'start': 6500, 'end': 7500},
15: {'start': 7000, 'end': 8000},
16: {'start': 7500, 'end': 8500},
17: {'start': 8000, 'end': 9000},
18: {'start': 8500, 'end': 9500},
19: {'start': 9000, 'end': 10000}}
'''
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment