Building a BPE (Byte-Pair Encoding) tokenizer from scratch.
# A minimal example of how to implement a byte-pair encoding (BPE) tokenizer from scratch in Python.
# Reference: https://github.com/karpathy/minbpe
# Contact: [email protected]
def get_stats(byte_arr):
    # get the frequency of each adjacent byte pair in the sequence
    count = {}
    for pair in zip(byte_arr[:-1], byte_arr[1:]):  # e.g. pair: (b'a', b' ')
        count[pair] = count.get(pair, 0) + 1
    return count
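
# A quick illustration of get_stats on a toy byte sequence (values are made up for this sketch,
# not taken from the corpus):
#   get_stats([b'a', b'b', b'a', b'b'])
#   -> {(b'a', b'b'): 2, (b'b', b'a'): 1}
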
def merge(text_bytes, pair, new_byte):
    # replace every adjacent occurrence of `pair` with the single merged byte string `new_byte`
    new_bytes = []
    i = 0
    while i < len(text_bytes):
        if i < len(text_bytes) - 1 and text_bytes[i] == pair[0] and text_bytes[i + 1] == pair[1]:  # merge
            new_bytes.append(new_byte)
            i += 2
        else:  # do not merge
            new_bytes.append(text_bytes[i])
            i += 1
    return new_bytes
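
# A quick illustration of merge (toy values): every (b'a', b'b') pair collapses into b'ab':
#   merge([b'a', b'b', b'c', b'a', b'b'], (b'a', b'b'), b'ab')
#   -> [b'ab', b'c', b'ab']
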
def encode(text, merges, vocab):
    # encode text into token ids using byte-pair encoding (BPE)
    bytes2id = {b: i for i, b in vocab.items()}
    sub_words = [bytes([b]) for b in text.encode('utf-8')]
    print(f"Before merging: {sub_words}")
    while len(sub_words) >= 2:
        pairs = [(x0, x1) for x0, x1 in zip(sub_words[:-1], sub_words[1:])]
        # replay merges in the order they were learned: pick the pair with the lowest merge index
        top_pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
        if top_pair not in merges:  # no mergeable pair remains
            break
        new_byte = b''.join(top_pair)
        sub_words = merge(sub_words, top_pair, new_byte)
        print(f"top pair: {top_pair}")
    tokens = [bytes2id[b] for b in sub_words]
    return tokens, sub_words
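
# Usage sketch (assuming `merges` and `vocab` were trained as in the __main__ block below):
#   tokens, sub_words = encode("good morning", merges, vocab)
# Pairs absent from `merges` rank as float("inf") in the `min` above, so the loop stops
# once no learnable pair remains; unmerged single bytes map to their raw ids below 256.
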
def decode(tokens, vocab):
    # decode tokens back to text by concatenating each token's byte string
    sub_words = [vocab[t] for t in tokens]
    text = b''.join(sub_words).decode('utf-8')
    return text
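
# Round-trip sketch: decode inverts encode as long as `merges` and `vocab` were built
# consistently, as in the __main__ block below:
#   decode(encode("good morning", merges, vocab)[0], vocab) == "good morning"
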
if __name__ == '__main__':
    # Read text
    # corpus url: https://www.reedbeta.com/blog/programmers-intro-to-unicode
    # download the corpus and save it as "corpus.txt" in the same directory as this script.
    with open("./corpus.txt", 'r', encoding='utf-8') as fp:
        text = fp.read().splitlines()[0]

    # convert to a byte sequence, e.g. [b'A', b' ', b'P', b'r', b'o', b'g', b'r', b'a', b'm', b'm']
    text_bytes = [bytes([b]) for b in text.encode('utf-8')]
    # some constants
    vocab_size = 300
    num_merge = vocab_size - 256  # number of merges to learn on top of the 256 raw bytes
    assert num_merge >= 0
    # greedily merge the pair with the highest frequency, one pair per iteration
    merges = {}
    for i in range(num_merge):
        stats = get_stats(text_bytes)
        top_pair = max(stats, key=stats.get)
        new_byte = b''.join(top_pair)
        print(f"merge {i + 1}: {top_pair} -> {new_byte}")
        text_bytes = merge(text_bytes, top_pair, new_byte)
        merges[top_pair] = 256 + i  # new token ids start right after the 256 raw bytes
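
    # A toy walk-through of one training step (illustrative values, not from corpus.txt):
    #   text_bytes = [b'a', b'b', b'a', b'b', b'c']
    #   get_stats  -> {(b'a', b'b'): 2, (b'b', b'a'): 1, (b'b', b'c'): 1}
    #   top_pair   = (b'a', b'b'); text_bytes becomes [b'ab', b'ab', b'c']
    #   merges[(b'a', b'b')] = 256
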
    # build vocabulary
    vocab = {i: bytes([i]) for i in range(256)}
    for pair, idx in merges.items():
        assert idx not in vocab
        vocab[idx] = b''.join(pair)
    assert len(vocab) == vocab_size
    # test encoding and decoding
    text = "good morning"
    tokens, sub_words = encode(text, merges, vocab)
    print(f"{text} -> {tokens} {sub_words}")
    text = decode(tokens, vocab)
    print(f"{tokens} -> {text}")