#!/usr/bin/env python3
"""
Example of dynamic/adaptive batching.

Author: Stephen Roller (twitter/github @stephenroller)
Public domain licensed. Do whatever you want with this.

Example usage:

$ wget -O bible.txt http://www.gutenberg.org/cache/epub/10/pg10.txt

$ python adaptive.py --nonadaptive bible.txt
Loading data
Number of docs = 24669
Arbitrary line of file: 2:6 But there went up a mist from the earth, and watered the whole face of the ground.
100%|█████████████████████████████| 24669/24669 [00:10<00:00, 2351.76it/s]
WARNING: Errant batch size: torch.Size([29, 216])
Total number of batches = 771
Average non-padding tokens / batch = 1412.0 (full = 16384)
Largest batchsize = torch.Size([32, 200])
Widest batch = torch.Size([32, 512])
Thinnest batch = torch.Size([32, 32])

$ python adaptive.py bible.txt
Loading data
Number of docs = 24669
Arbitrary line of file: 2:6 But there went up a mist from the earth, and watered the whole face of the ground.
100%|█████████████████████████████| 24669/24669 [00:10<00:00, 2394.39it/s]
WARNING: Errant batch size: torch.Size([5, 376])
Total number of batches = 83
Average non-padding tokens / batch = 13119.0 (full = 16384)
Largest batchsize = torch.Size([680, 24])
Widest batch = torch.Size([8, 512])
Thinnest batch = torch.Size([680, 24])
"""
import argparse
import random

import torch
import tqdm
from transformers import GPT2TokenizerFast


MAX_LENGTH = 512  # max document length
BATCH_SIZE = 32  # worst case (minimum) batch size
TOKENS_PER_BATCH = MAX_LENGTH * BATCH_SIZE
BUFFER_SIZE = 4096  # trade off efficiency for memory
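# Sizing note: with the defaults above, a "full" batch holds
# MAX_LENGTH * BATCH_SIZE = 512 * 32 = 16384 token slots, which is the
# "(full = 16384)" figure in the example output in the docstring. The naive
# batcher always emits BATCH_SIZE rows padded out to the longest row, while
# the adaptive batcher packs as many similar-length documents as fit under
# the same token budget.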
def make_documents(filebuf):
    # Gutenberg text is hard-wrapped at 80 characters, so joining lines until
    # a blank line is roughly equivalent to splitting the document on \n\n.
    all_documents = []
    doc = []
    for line in filebuf:
        line = line.rstrip()
        if line:
            doc.append(line)
        elif doc:
            doctxt = " ".join(doc).replace("  ", " ")
            all_documents.append(doctxt)
            doc = []
    if doc:
        doctxt = " ".join(doc).replace("  ", " ")
        all_documents.append(doctxt)
    return all_documents
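# Illustration (hypothetical wrapping, not copied from the actual file): the
# wrapped lines
#
#     1:1 In the beginning God created the heaven
#     and the earth.
#     <blank line>
#
# come out of make_documents() as the single document
# "1:1 In the beginning God created the heaven and the earth."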
def adaptive_extract_batch(buf):
    """
    Pull out a batch from the buffer. Modifies the buffer in place.
    """
    # sort to put like-length examples together
    buf.sort()
    batch = []
    longest = 0
    start = random.randint(0, len(buf) - 1)
    while buf and len(batch) * longest < TOKENS_PER_BATCH:
        cost, r, tokens = buf.pop(start)
        longest = max(len(tokens), longest)
        if longest % 8 != 0:
            # round the width up to a multiple of 8
            longest = longest + 8 - longest % 8
        batch.append((cost, r, tokens))
        start = max(start - 1, 0)
    # we need to force the batch size to be a multiple of 8 for fp16, so put
    # things back into the buffer
    while len(batch) > 8 and len(batch) % 8 != 0:
        buf.append(batch.pop(0))
    batch_tensor = torch.zeros((len(batch), longest), dtype=torch.int64)
    for i, (_, _, tokens) in enumerate(batch):
        batch_tensor[i, : len(tokens)] = tokens
    return batch_tensor
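# Why the random window works: after sorting, documents of similar token
# length sit next to each other in the buffer, so popping a contiguous run
# starting at a random index yields rows that need very little padding. The
# random start (and the random tie-breaker stored alongside each document)
# keeps the batch composition from being fully deterministic.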
def naive_extract_batch(buf):
    """
    Pull out a batch naively (in order).
    """
    batch = []
    longest = 0
    for i in range(BATCH_SIZE):
        if not buf:
            break
        cost, r, tokens = buf.pop(0)
        batch.append(tokens)
        longest = max(longest, len(tokens))
        if longest % 8 != 0:
            longest = longest + 8 - longest % 8
    batch_tensor = torch.zeros((len(batch), longest), dtype=torch.int64)
    for i, tokens in enumerate(batch):
        batch_tensor[i, : len(tokens)] = tokens
    return batch_tensor
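# For comparison: the example runs in the docstring show that in-order batches
# of a fixed 32 rows average only ~1412 non-padding tokens out of the
# 16384-token budget, while the adaptive batches average ~13119.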
def batcher(stream, tokenizer, extract_fn):
    buf = []
    for item in stream:
        tokenized = tokenizer(
            item, max_length=MAX_LENGTH, return_tensors='pt', truncation=True,
        )['input_ids'][0]
        # cost is the (truncated) length, bucketed into multiples of 8
        cost = len(tokenized) // 8
        buf.append((cost, random.random(), tokenized))
        if len(buf) >= BUFFER_SIZE:
            yield extract_fn(buf)
    # ate all the data. handle whatever is leftover in the buffer
    while buf:
        yield extract_fn(buf)
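# A minimal sketch of how these batches might be consumed (hypothetical;
# `model` and `optimizer` are not part of this script, and a real training
# loop would also build an attention mask so the zero-padding is ignored):
#
#     model = GPT2LMHeadModel.from_pretrained("gpt2")
#     for batch in batcher(docs, tokenizer, adaptive_extract_batch):
#         loss = model(input_ids=batch, labels=batch).loss
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()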
def main(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=argparse.FileType('r'))
    parser.add_argument('--nonadaptive', action='store_true')
    opts = parser.parse_args(args)

    # load up data
    print("Loading data")
    docs = make_documents(opts.file)
    print(f"Number of docs = {len(docs)}")
    print(f"Arbitrary line of file: {docs[42]}")

    # initialize the tokenizer
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    max_width = 0
    max_height = 0
    min_width = MAX_LENGTH + 1
    widest_batch = None
    thinnest_batch = None
    longest_batch = None
    tokens_per_batch = []

    if opts.nonadaptive:
        extract = naive_extract_batch
    else:
        extract = adaptive_extract_batch

    for b, batch in enumerate(batcher(tqdm.tqdm(docs, ncols=74), tokenizer, extract)):
        height, width = batch.shape
        tokens_per_batch.append((batch != 0).sum())
        assert width % 8 == 0
        assert height > 0
        assert width > 0
        if height % 8 != 0:
            print(f"WARNING: Errant batch size: {batch.shape}")
        if height >= max_height:
            longest_batch = batch.shape
            max_height = height
        if width >= max_width:
            widest_batch = batch.shape
            max_width = width
        if width <= min_width:
            thinnest_batch = batch.shape
            min_width = width

    print(f"Total number of batches = {b + 1}")
    tokens_per_batch = sum(tokens_per_batch) / len(tokens_per_batch)
    print(
        f"Average non-padding tokens / batch = {tokens_per_batch:.1f} (full = {TOKENS_PER_BATCH})"
    )
    print(f"Largest batchsize = {longest_batch}")
    print(f"Widest batch = {widest_batch}")
    print(f"Thinnest batch = {thinnest_batch}")


if __name__ == '__main__':
    main()