Last active
August 21, 2024 08:27
-
-
Save xiabingquan/e10ee21c06dfe8e8d237fcd31c3927cb to your computer and use it in GitHub Desktop.
Building a BPE (Byte-Pair Encoding) tokenizer from scratch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A minimal example of how to implement byte-pair encoding (BPE) tokenizer from scratch in Python.
# Reference: https://github.com/karpathy/minbpe
# Contact: [email protected]
def get_stats(byte_arr):
    """Count how often each adjacent byte pair occurs in the sequence.

    Returns a dict mapping (left, right) byte pairs to their frequency.
    """
    freq = {}
    for left, right in zip(byte_arr, byte_arr[1:]):  # e.g. (b'a', b' ')
        key = (left, right)
        freq[key] = freq.get(key, 0) + 1
    return freq
def merge(text_bytes, pair, new_byte):
    """Replace each non-overlapping left-to-right occurrence of `pair`
    (two adjacent elements) with the single element `new_byte`.

    Returns a new list; the input list is not modified.
    """
    first, second = pair
    merged = []
    n = len(text_bytes)
    i = 0
    while i < n:
        # A match consumes two positions; anything else passes through.
        if i + 1 < n and text_bytes[i] == first and text_bytes[i + 1] == second:
            merged.append(new_byte)
            i += 2
        else:
            merged.append(text_bytes[i])
            i += 1
    return merged
def encode(text, merges, vocab):
    """Encode `text` into BPE token ids.

    `merges` maps byte pairs to their merge rank (lower = earlier merge);
    `vocab` maps token id -> bytes. Returns (token ids, sub-word bytes).
    """
    # Invert the vocab so we can look up an id for each final sub-word.
    bytes2id = {token: idx for idx, token in vocab.items()}
    sub_words = [bytes([b]) for b in text.encode('utf-8')]
    print(f"Before merged: {sub_words}")
    while len(sub_words) >= 2:
        # Apply the learned merge with the lowest rank among current pairs.
        pairs = list(zip(sub_words, sub_words[1:]))
        top_pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
        if top_pair not in merges:
            break  # no learnable pair remains
        sub_words = merge(sub_words, top_pair, b''.join(top_pair))
        print(f"top pair: {top_pair}")
    tokens = [bytes2id[b] for b in sub_words]
    return tokens, sub_words
def decode(tokens, vocab):
    """Decode BPE token ids back into a UTF-8 string.

    `vocab` maps token id -> bytes; concatenating the pieces and decoding
    as UTF-8 recovers the original text.
    """
    pieces = (vocab[t] for t in tokens)
    return b''.join(pieces).decode('utf-8')
if __name__ == '__main__':
    # Read the training corpus (first line only).
    # corpus url: https://www.reedbeta.com/blog/programmers-intro-to-unicode
    # download the corpus and save it as "corpus.txt" in the same directory as this script.
    with open("./corpus.txt", 'r', encoding='utf-8') as fp:
        text = fp.read().splitlines()[0]

    # Convert the text to a sequence of single-byte tokens,
    # e.g. [b'A', b' ', b'P', b'r', b'o', b'g', b'r', b'a', b'm', b'm'].
    text_bytes = [bytes([b]) for b in text.encode('utf-8')]

    # Each merge adds exactly one token beyond the 256 base byte values.
    vocab_size = 300
    num_merge = vocab_size - 256
    assert num_merge >= 0

    # Iteratively merge the most frequent adjacent pair, recording the
    # rank of each merge. (Fix: the original also called get_stats once
    # before the loop and discarded the result; that dead call is removed.)
    merges = {}
    for i in range(num_merge):
        stats = get_stats(text_bytes)
        top_pair = max(stats, key=stats.get)
        new_byte = b''.join(top_pair)
        print(f"{i+ 1}th merge: {top_pair} -> {new_byte}")
        text_bytes = merge(text_bytes, top_pair, new_byte)
        merges[top_pair] = 256 + i

    # Build the vocabulary: 256 base bytes plus one entry per merge.
    vocab = {i: bytes([i]) for i in range(256)}
    for pair, idx in merges.items():
        assert idx not in vocab
        vocab[idx] = b''.join(pair)
    assert len(vocab) == vocab_size

    # Round-trip check: encoding then decoding should reproduce the text.
    text = "good morning"
    tokens, sub_words = encode(text, merges, vocab)
    print(f"{text} -> {tokens} {sub_words}")
    text = decode(tokens, vocab)
    print(f"{tokens} -> {text}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment