Created
June 24, 2019 03:29
-
-
Save 8enmann/86be66859735fb7e33a2f36041fa433c to your computer and use it in GitHub Desktop.
Approximate BPE implementation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Implements an approximate BPE encoding over bytes with some tricks for efficiency.

https://arxiv.org/pdf/1508.07909.pdf section 3.2.

Basic algorithm from the paper:
    Initialize the vocab with the character vocabulary.
    Each word is a sequence of characters plus an end-of-word symbol '·'.
    Count all symbol pairs.
    Replace each occurrence of the most frequent pair ('a', 'b') with 'ab'.
        Each merge represents a character n-gram.
        Frequent n-grams are merged into a single symbol.
    Repeat until max vocab size or computation budget is reached.

Unlike the paper, this implementation operates directly on UTF-8 bytes (the
whole text is one byte sequence, with no word splitting or end-of-word symbol),
so it should work for any language or data type with no modification.
It provides the option to do multiple replacements per iteration for increased speed.
Encoding using a computed vocab is done greedily instead of by the standard algorithm.

TODO: benchmark against original.
"""
import multiprocessing as mp | |
from collections import Counter, deque | |
from typing import Dict, Iterable, List, Set, Tuple | |
import tqdm | |
def get_pairs(seq: Iterable) -> Iterable[Tuple]:
    """Yield each pair of adjacent elements of `seq` as a 2-tuple.

    For example `get_pairs('abc')` yields `('a', 'b')` and then `('b', 'c')`.
    A sequence with fewer than two elements yields nothing.

    Args:
        seq: Any iterable. It is consumed exactly once, so generators work.

    Yields:
        `(previous, current)` for every adjacent position, in order.
    """
    it = iter(seq)
    # Fetch the first element explicitly. A StopIteration escaping from
    # `next()` inside a generator is converted to RuntimeError (PEP 479), so
    # empty input must be handled here rather than left to propagate.
    try:
        prev = next(it)
    except StopIteration:
        return
    for item in it:
        yield (prev, item)
        prev = item
class Worker(mp.Process):
    """Computes pair counts on one chunk of the corpus in a child process.

    Protocol with the parent (`compute_vocab_multi`):

    1. Put the chunk's initial symbol set on `vocab_q`.
    2. Put the chunk's adjacent-pair `Counter` on `count_q`, then wait on
       `top_k_ready`.
    3. When woken, read `top_k`. If it is empty, stop. Otherwise merge those
       pairs into the chunk and go back to step 2.

    Queues carry data child -> parent only. `top_k` is only read here.
    """

    def __init__(
            self,
            top_k_ready: mp.Condition,
            top_k: 'DictProxy',
            count_q: mp.Queue,
            vocab_q: mp.Queue,
            corpus: str,
            max_merges: int = None):
        """
        Args:
            top_k_ready: Condition the parent notifies once `top_k` is ready.
            top_k: Shared mapping of pair -> merged symbol for the current
                round. An empty mapping (after the first round) means "stop".
            count_q: Queue on which each round's pair counts are sent.
            vocab_q: Queue on which the initial symbol set is sent once.
            corpus: This worker's chunk of the corpus.
            max_merges: Optional. If given, the worker exits after this many
                merge rounds instead of sending counts that will not be read.
                It is accepted positionally after `corpus` because callers
                pass it there.
        """
        super(Worker, self).__init__()
        self.top_k_ready = top_k_ready
        self.top_k = top_k
        self.vocab_q = vocab_q
        self.count_q = count_q
        self.corpus = corpus
        self.max_merges = max_merges
        self.byte_list: Iterable[bytes] = None

    def run(self):
        """This shouldn't be called directly; call `worker.start()`."""
        print('started', self.name)
        self.byte_list = str_to_byte_list(self.corpus)
        self.vocab = set(self.byte_list)
        self.vocab_q.put(self.vocab)
        rounds = 0
        while True:
            counts = Counter(get_pairs(self.byte_list))
            # Hold the condition's lock from before the counts are sent until
            # wait() atomically releases it. Notifying requires the same lock,
            # so the parent cannot notify in the gap between our put() and our
            # wait(). If the lock were released in that gap, the notification
            # would be lost and both processes would block forever.
            with self.top_k_ready:
                if rounds and len(self.top_k) == 0:
                    # "Stop" was posted while we were merging, and its
                    # notification reached nobody. Don't wait for it.
                    break
                self.count_q.put(counts)
                self.top_k_ready.wait()
            if len(self.top_k) == 0:
                break
            self.byte_list = list(merge(self.top_k, self.byte_list))
            rounds += 1
            if self.max_merges is not None and rounds >= self.max_merges:
                break
def compute_vocab_multi(
        corpus: str,
        max_vocab_size: int = 3000,
        max_merges: int = 10, top_k=1,
        n: int = mp.cpu_count()) -> Set[bytes]:
    """Multiprocess implementation of approximate BPE.

    Divides the corpus among n `Worker` processes. Each round, every worker
    reports the pair counts of its chunk. The parent sums them, picks the
    `top_k` most common pairs and tells every worker to merge them.

    Args:
        corpus: The corpus to encode. Could scale better by taking a list of filenames.
        max_vocab_size: Stop after generating this many vocab entries.
        max_merges: Stop after this many rounds.
        top_k: Each round merge the top k pairs. Standard BPE sets top_k=1.
        n: Number of worker processes. Capped at `len(corpus)` (minimum 1)
            so that tiny corpora don't hand several workers the same chunk.
    Returns:
        A set of all the vocab entries generated, each of which is a `bytes`.
    """
    n = max(1, min(n, len(corpus)))
    top_k_ready = mp.Condition()
    vocab_q = mp.Queue()
    count_q = mp.Queue()
    chunk_size = len(corpus) // n
    vocab = set()
    with mp.Manager() as manager:
        to_merge = manager.dict()
        procs = []
        print('starting workers')
        for i in range(n):
            # Chunks overlap by one character on purpose, so pairs straddling
            # a boundary are still counted. The last chunk runs to the end of
            # the corpus so the remainder of `len(corpus) // n` is not lost.
            end = len(corpus) if i == n - 1 else (i + 1) * chunk_size + 1
            procs.append(Worker(
                top_k_ready,
                to_merge,
                count_q,
                vocab_q,
                corpus[i * chunk_size:end],
            ))
            procs[-1].start()
        # Get initial vocab from each worker.
        print('waiting for vocab from worker')
        for _ in range(n):
            vocab.update(vocab_q.get())
        print('got vocab', vocab)
        for _ in range(max_merges):
            # Each round's counts describe the workers' current symbols only.
            # Carrying them over would keep already-merged (no longer
            # existing) pairs at the top forever.
            counts = Counter()
            for _ in range(n):
                counts.update(count_q.get())
            print(counts)
            merges = {pair: b''.join(pair) for pair, _ in counts.most_common(top_k)}
            if not merges:
                # No pairs left anywhere. An empty `to_merge` is the workers'
                # stop signal, so it must not be sent as a regular round
                # (we would then wait forever for counts that never come).
                break
            to_merge.clear()
            to_merge.update(merges)
            vocab.update(merges.values())
            with top_k_ready:
                top_k_ready.notify_all()
            if len(vocab) >= max_vocab_size:
                break
        # Tell workers to stop: an empty `to_merge` plus a notification.
        # A worker that is still busy when we notify misses that wake-up, so
        # re-notify a bounded number of times while any worker is alive.
        # Then terminate stragglers rather than leaving child processes
        # running after the manager shuts down.
        to_merge.clear()
        for _ in range(10):
            with top_k_ready:
                top_k_ready.notify_all()
            for p in procs:
                p.join(0.1)
            if not any(p.is_alive() for p in procs):
                break
        for p in procs:
            if p.is_alive():
                p.terminate()
                p.join(1)
    return vocab
def merge(to_merge: Dict[Tuple[bytes], bytes], seq: Iterable) -> Iterable:
    """Apply the requested pair merges to `seq` in one left-to-right pass.

    Scanning from the left, whenever the current symbol and the next one form
    a pair in `to_merge`, both are replaced by their concatenation and the
    scan resumes after the pair, so merges never overlap. For example,
    merging `(b'a', b'a')` over `[b'a', b'a', b'a']` gives `[b'aa', b'a']`.

    Args:
        to_merge: Mapping whose keys are the pairs to merge. Only the keys are
            used; the merged symbol is always `b''.join(pair)`. May be a
            multiprocessing manager dict (as passed by `Worker`).
        seq: The symbols to merge, consumed once. Empty and one-element
            inputs come back unchanged.

    Yields:
        The symbols of `seq` after merging.
    """
    # Snapshot the keys into a plain local dict with a single `.keys()` call,
    # so the scan below does ordinary dict lookups instead of repeated calls
    # on a (possibly proxied) shared mapping.
    merges = {pair: b''.join(pair) for pair in to_merge.keys()}
    items = list(seq)
    last = len(items) - 1
    i = 0
    while i <= last:
        if i < last:
            joined = merges.get((items[i], items[i + 1]))
            if joined is not None:
                yield joined
                i += 2  # Both symbols were consumed by the merge.
                continue
        yield items[i]
        i += 1
def str_to_byte_list(s: str) -> Iterable[bytes]:
    """Split the UTF-8 encoding of `s` into a list of one-byte `bytes` objects.

    Multi-byte characters are split into their individual bytes, e.g.
    'é' becomes `[b'\\xc3', b'\\xa9']`.
    """
    encoded = s.encode('utf8')
    # Slicing `bytes` (rather than indexing, which gives an int) keeps each
    # element as a length-1 `bytes` object.
    return [encoded[pos:pos + 1] for pos in range(len(encoded))]
def compute_vocab(corpus: str, max_vocab_size: int = 3000, max_merges: int = 10, top_k=1) -> Set[bytes]:
    """Single threaded implementation of approximate BPE.

    Args:
        corpus: The corpus to encode. Could scale better by taking a list of filenames.
        max_vocab_size: Stop after generating this many vocab entries.
        max_merges: Stop after this many rounds.
        top_k: Each round merge the top k pairs. Standard BPE sets top_k=1.
    Returns:
        A set of all the vocab entries generated, each of which is a `bytes`.
    Raises:
        Exception: if `corpus` has fewer characters than
            `min(max_merges, max_vocab_size)`.
    """
    if len(corpus) < min(max_merges, max_vocab_size):
        raise Exception('Corpus must be bigger than max_merges')
    symbols = str_to_byte_list(corpus)
    vocab = set(symbols)
    for _ in tqdm.trange(max_merges):
        if len(symbols) < 2:
            # Everything has been merged into at most one symbol, so there
            # are no pairs left to count. With top_k > 1 this can happen well
            # before max_merges rounds.
            break
        counts = Counter(get_pairs(symbols))
        # Merge the most common.
        to_merge = {x[0]: b''.join(x[0]) for x in counts.most_common(top_k)}
        vocab.update(to_merge.values())
        symbols = list(merge(to_merge, symbols))
        if len(vocab) >= max_vocab_size:
            break
    return vocab
class Encoder:
    """Greedy encoder/decoder between text and ids of a byte-level BPE vocab.

    Ids are positions in `self.vocab`: the supplied (or loaded) entries in
    their iteration order, followed by `EOF` and then `UNK`, so `UNK_EMB`
    is always the last id.
    """
    DEFAULT_VOCAB_FILENAME = 'vocab.bpe'
    # Null bytes unlikely to occur in natural encoded text.
    DELIM = b'\0\n'
    # Must be 2 characters long because otherwise probably won't have intermediate merges for the greedy encoder to pick up.
    EOF = b'\0F'
    UNK = b'\0UNK'

    def __init__(self, vocab: Iterable[bytes] = None, vocab_file: str = DEFAULT_VOCAB_FILENAME):
        """
        Args:
            vocab: Vocab entries, as any iterable (e.g. the set returned by
                `compute_vocab`). It is copied and never modified. If it is
                None or empty, the vocab is loaded from `vocab_file` instead.
            vocab_file: A file written by `Encoder.save`.
        """
        if vocab:
            # Copy: appending the special tokens below must not mutate the
            # caller's list, and sets/other iterables have no `+=` with a list.
            self.vocab = list(vocab)
        else:
            self.vocab = self.load(vocab_file)
        # Append special characters.
        self.vocab += [self.EOF, self.UNK]
        # Both maps are built from self.vocab (which includes the special
        # tokens), not from the `vocab` argument, which may be None.
        self.encoder = {x: i for i, x in enumerate(self.vocab)}
        self.decoder = dict(enumerate(self.vocab))
        self.max_length = max(map(len, self.vocab))
        self.UNK_EMB = len(self.vocab) - 1

    def encode(self, corpus: str) -> Iterable[int]:
        """Greedily encode `corpus` according to the vocab.

        At each position the match is extended one byte at a time for as long
        as the longer prefix is still a vocab entry, up to `max_length` bytes.
        The longest such match is emitted and encoding resumes immediately
        after it. A byte that starts no vocab entry is emitted as `UNK_EMB`.

        Yields:
            Vocab ids, lazily.
        """
        data = corpus.encode('utf8')
        size = len(data)
        start = 0
        while start < size:
            token, token_len = self.UNK_EMB, 1
            for length in range(1, min(self.max_length, size - start) + 1):
                candidate = self.encoder.get(data[start:start + length])
                if candidate is None:
                    # Greedy: stop at the first prefix that isn't in the
                    # vocab, even if some longer one would be.
                    break
                token, token_len = candidate, length
            yield token
            # Advance past exactly the bytes the emitted token covers
            # (a single byte for UNK).
            start += token_len

    def decode(self, corpus: Iterable[int], errors='ignore') -> str:
        """Decode `corpus` according to the vocab.

        The byte strings of all ids are concatenated and decoded as UTF-8,
        with `errors` passed through to `bytes.decode`.
        """
        return b''.join([self.decoder[x] for x in corpus]).decode('utf8', errors=errors)

    @classmethod
    def save(cls, vocab: Iterable[bytes], filename: str = DEFAULT_VOCAB_FILENAME):
        """Write each vocab entry followed by `DELIM` to `filename`.

        An entry that itself contains `DELIM` will not round-trip through
        `load`.
        """
        with open(filename, 'wb') as f:
            for v in vocab:
                f.write(v + cls.DELIM)

    @classmethod
    def load(cls, filename: str = DEFAULT_VOCAB_FILENAME):
        """Read a file written by `save` and return its entries as a list."""
        with open(filename, 'rb') as f:
            # Every entry is followed by DELIM, so the final chunk is empty.
            return f.read().split(cls.DELIM)[:-1]
def main(corpus_file: str = '/Users/ben/data/wikitext-2/wiki.train.tokens'):
    """Learn a vocab from the first 10000 characters of a corpus, then save and reload it.

    Args:
        corpus_file: Text file to learn from, read as UTF-8. Defaults to the
            original hard-coded local path of the WikiText-2 training tokens.
    """
    # Decode explicitly as UTF-8 rather than with the platform default
    # encoding, which differs between systems.
    with open(corpus_file, encoding='utf8') as f:
        corpus = f.read()
    print(len(corpus))
    vocab = compute_vocab(corpus[:10000], max_merges=100, top_k=10)
    print(len(vocab))
    # Save the mapping
    Encoder.save(vocab)
    print(len(Encoder.load()))


if __name__ == '__main__':
    import sys
    # Optional first command-line argument: path of the corpus file.
    main(*sys.argv[1:2])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Tests for bpe.py. | |
One of the tests uses a separate process. | |
The others mock out multiprocessing functionality, so should be lightweight. | |
""" | |
from unittest import mock | |
import pytest | |
import multiprocessing | |
import bpe | |
def test_get_pairs():
    """get_pairs yields one pair per adjacent position, in order."""
    text = 'aaabbb'
    assert len(list(bpe.get_pairs(text))) == len(text) - 1
    expected = [('a', 'b'), ('b', 'c')]
    assert list(bpe.get_pairs('abc')) == expected
def test_str_to_byte_list():
    """Each byte of the encoded string becomes its own `bytes` element."""
    result = bpe.str_to_byte_list('ab')
    assert result == [b'a', b'b']
def test_merge():
    """Merging (a, b) over 'abc' joins the pair and keeps the trailing 'c'."""
    symbols = bpe.str_to_byte_list('abc')
    pair = (b'a', b'b')
    merges = {pair: b''.join(pair)}
    assert list(bpe.merge(merges, symbols)) == [b'ab', b'c']
def test_compute_vocab_simple():
    """With default settings a short sentence yields a 41-entry vocab."""
    sentence = 'The quick brown fox jumped over the lazy dog. Wow! Amazing.'
    assert len(bpe.compute_vocab(sentence)) == 41
def test_encode():
    """Round-trip through a single-byte vocab, and map unknown bytes to UNK."""
    encoder = bpe.Encoder(bpe.str_to_byte_list('abcdef'))
    text = 'aabb'
    ids = list(encoder.encode(text))
    assert ids == [0, 0, 1, 1]
    assert encoder.decode(ids) == text
    # A byte missing from the vocab decodes back to the UNK marker.
    unk_text = encoder.UNK.decode('utf8')
    assert encoder.decode(encoder.encode('t')) == unk_text
@mock.patch('multiprocessing.Queue')
@mock.patch('bpe.Worker')
@mock.patch('multiprocessing.Condition')
def test_compute_vocab_multi(condition_cls, worker_cls, queue_cls):
    """One mocked worker: the top pair is merged, then the size limit stops it."""
    fake_queue = queue_cls.return_value
    fake_queue.get.side_effect = [
        {b'a'},                              # initial vocab
        {(b'a', b'a'): 3, (b'b', b'b'): 2},  # first round of counts
    ]
    vocab = bpe.compute_vocab_multi('aaabb', n=1, max_vocab_size=2)
    positional_args = worker_cls.call_args[0]
    assert 'aaabb' in positional_args
    assert vocab == {b'aa', b'a'}
@mock.patch('multiprocessing.Queue')
@mock.patch('bpe.Worker')
@mock.patch('multiprocessing.Condition')
def test_compute_vocab_multi_corpus_partition(condition_cls, worker_cls, queue_cls):
    """Two workers receive overlapping halves of the corpus."""
    fake_queue = queue_cls.return_value
    # Every get() returns an empty list, so no vocab or counts ever arrive.
    fake_queue.get.return_value = []
    vocab = bpe.compute_vocab_multi('aaabb', n=2, max_vocab_size=0)
    first_call, second_call = worker_cls.call_args_list[0], worker_cls.call_args_list[1]
    assert 'aaa' in first_call[0]
    assert 'abb' in second_call[0]
    assert vocab == set()
def test_worker():
    """Drive a real Worker process through two merge rounds and a stop signal.

    Every blocking call has a timeout, so a protocol deadlock fails the test
    (queue.Empty / assertion) instead of hanging the whole test run.
    """
    timeout = 10  # seconds
    top_k_ready = multiprocessing.Condition()
    with multiprocessing.Manager() as m:
        top_k = m.dict()
        count_q = multiprocessing.Queue()
        vocab_q = multiprocessing.Queue()
        worker = bpe.Worker(top_k_ready, top_k, count_q, vocab_q, 'aaabb')
        worker.start()
        try:
            assert vocab_q.get(timeout=timeout) == {b'a', b'b'}
            counts = count_q.get(timeout=timeout)
            assert counts == {(b'a', b'a'): 2, (b'a', b'b'): 1, (b'b', b'b'): 1}
            top_k.update({x[0]: b''.join(x[0]) for x in counts.most_common(2)})
            with top_k_ready:
                top_k_ready.notify()
            # Round 2.
            counts = count_q.get(timeout=timeout)
            assert counts == {(b'aa', b'ab'): 1, (b'ab', b'b'): 1}
            # Finish.
            top_k.clear()
            with top_k_ready:
                top_k_ready.notify()
            worker.join(timeout)
            assert not worker.is_alive()
        finally:
            # Never leave a stuck child process behind after a failure.
            if worker.is_alive():
                worker.terminate()
                worker.join()
if __name__ == '__main__':
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails.
    raise SystemExit(pytest.main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment