Created
February 22, 2024 18:38
-
-
Save iamlemec/52eaa4961762efb9c064b871a67f6cc6 to your computer and use it in GitHub Desktop.
Compare tokenization results between `llama-cpp-python` and Huggingface `tokenizers`.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def check_tokenizer(mod_ll, mod_hf, data, max_rows=None):
    """Print a colored diff for every row where the two tokenizers disagree.

    Each text row is tokenized with the llama.cpp model and with the
    Huggingface tokenizer. Rows whose token-id sequences differ are reported
    on stdout as an inline diff aligned with Levenshtein edit operations.

    Args:
        mod_ll: Either a string, which is passed to ``Llama(...)`` to load the
            model, or an already-loaded object exposing ``tokenize(bytes)``.
        mod_hf: Either a string, which is passed to
            ``AutoTokenizer.from_pretrained``, or an already-loaded tokenizer
            exposing ``encode(str)`` and ``_tokenizer.id_to_token(int)``.
        data: Either a path to a text file (one sample per line, read as
            UTF-8) or a sliceable sequence of strings.
        max_rows: If not ``None``, only the first ``max_rows`` rows are checked.

    Returns:
        None. All results are printed.

    Output legend (the llama.cpp sequence is the diff source):
        ``[-tok]``      red   -- token only in the llama.cpp output
        ``[+tok]``      green -- token only in the Huggingface output
        ``[tok1→tok2]``       -- llama.cpp token replaced by Huggingface token
    """
    from llama_cpp import Llama
    from transformers import AutoTokenizer
    from Levenshtein import editops
    from termcolor import cprint

    # load models
    if isinstance(mod_ll, str):
        mod_ll = Llama(mod_ll, verbose=False)
    if isinstance(mod_hf, str):
        mod_hf = AutoTokenizer.from_pretrained(mod_hf)

    # load data
    if isinstance(data, str):
        # Read as UTF-8 (the text is re-encoded as UTF-8 for llama.cpp below)
        # and close the file deterministically.
        with open(data, encoding='utf-8') as fid:
            data = fid.read().splitlines()
    if max_rows is not None:
        data = data[:max_rows]

    # compute token ids
    ids_ll = [mod_ll.tokenize(text.encode('utf-8')) for text in data]
    ids_st = [mod_hf.encode(text) for text in data]

    def tokmap(i, replace=False):
        """Return the display string for token id ``i``.

        Tokens starting with ``##`` are glued to the previous token; all
        others get a leading separator (``_`` inside diff markers, a space
        in running text).
        """
        tok = mod_hf._tokenizer.id_to_token(i)
        if tok is None:
            # id_to_token yields None for ids outside the Huggingface vocab
            # (possible when the two vocabularies differ); show the raw id.
            tok = f'<{i}>'
        if tok.startswith('##'):
            return tok[2:]
        pre = '_' if replace else ' '
        return f'{pre}{tok}'

    # compare token ids
    for i, (id_ll, id_st) in enumerate(zip(ids_ll, ids_st)):
        if id_ll == id_st:
            continue
        print(f'Mismatch at index {i}')

        # Inserts do not consume a source token, so several of them (plus a
        # replace/delete) can share one source position. Keep them apart
        # instead of letting them overwrite each other in a single dict.
        inserts = {}  # source position -> destination positions inserted before it
        changes = {}  # source position -> (op, destination position)
        for op, pos1, pos2 in editops(id_ll, id_st):
            if op == 'insert':
                inserts.setdefault(pos1, []).append(pos2)
            else:
                changes[pos1] = (op, pos2)

        n_ll = len(id_ll)
        # Walk one step past the end so inserts appended after the last
        # llama.cpp token are rendered as well.
        for pos1 in range(n_ll + 1):
            for pos2 in inserts.get(pos1, ()):
                tok2 = tokmap(id_st[pos2], replace=True)
                cprint(f'[+{tok2}]', color='green', attrs=['bold'], end='')
            if pos1 == n_ll:
                break

            id1 = id_ll[pos1]
            if pos1 not in changes:
                print(tokmap(id1, replace=False), end='')
                continue

            op, pos2 = changes[pos1]
            tok1 = tokmap(id1, replace=True)
            if op == 'delete':
                # The destination position of a delete may be one past the
                # end of id_st, so it must not be used as an index.
                cprint(f'[-{tok1}]', color='red', attrs=['bold'], end='')
            elif op == 'replace':
                tok2 = tokmap(id_st[pos2], replace=True)
                print('[', end='')
                cprint(f'{tok1}', color='red', attrs=['bold'], end='')
                cprint(f'→{tok2}', color='green', attrs=['bold'], end='')
                print(']', end='')
        print('\n')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment