Created
May 8, 2024 15:10
-
-
Save pcuenca/917353acb55ba3d3347b21e82e91926e to your computer and use it in GitHub Desktop.
Gemma Tokenizer Tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "f9427918", | |
"metadata": {}, | |
"source": [ | |
"## Load the reference `gemma` tokenizer and the `transformers` tokenizers (fast and slow)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "48a7fddf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from huggingface_hub import hf_hub_download\n", | |
"\n", | |
"original_path = hf_hub_download(repo_id=\"google/codegemma-1.1-2b\", filename=\"tokenizer.model\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "aae9d9de", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from gemma.tokenizer import Tokenizer\n", | |
"\n", | |
"# Reference tokenizer from the `gemma` package, built from the `tokenizer.model` file downloaded above.\n", | |
"original = Tokenizer(original_path)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "06b063cf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from transformers import GemmaTokenizer, AutoTokenizer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "a584d69f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Fails for \"main\"\n", | |
"revision = \"refs/pr/4\"\n", | |
"\n", | |
"t_fast = AutoTokenizer.from_pretrained(\"google/codegemma-1.1-7b-it\", revision=revision)\n", | |
"t_slow = GemmaTokenizer.from_pretrained(\"google/codegemma-1.1-7b-it\", revision=revision)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "72a1c087", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"for s in [\n", | |
" '<start_of_turn>', '<end_of_turn>', '<mask>',\n", | |
" '<|fim_prefix|>', '<|fim_suffix|>', '<|fim_middle|>', '<|file_separator|>'\n", | |
"]:\n", | |
" encoded = original.encode(s, bos=False, eos=False)\n", | |
" assert t_fast.encode(s, add_special_tokens=False) == encoded, f\"Failed: {s}\"\n", | |
" assert t_slow.encode(s, add_special_tokens=False) == encoded, f\"Failed: {s}\"\n", | |
" assert t_fast.decode(encoded) == s, f\"Failed: {s}\"\n", | |
" assert t_slow.decode(encoded) == s, f\"Failed: {s}\"" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8ab89d7b", | |
"metadata": {}, | |
"source": [ | |
"## Verify on XNLI (validation split)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "0160405a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from datasets import load_dataset\n", | |
"from tqdm import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "a743115c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# `all_languages` config: each `premise` is a {language: text} mapping (iterated in the last cell).\n", | |
"# NOTE(review): newer `datasets` releases may require the namespaced id \"facebook/xnli\" - confirm.\n", | |
"xnli = load_dataset(\"xnli\", \"all_languages\", split=\"validation\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "9a52691b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def verify(lang, text):\n", | |
" encoded_original = original.encode(text, bos=True, eos=False)\n", | |
" encoded_fast = t_fast.encode(text)\n", | |
" encoded_slow = t_slow.encode(text)\n", | |
" assert encoded_fast == encoded_original, f\"Fast encode error: {lang} - {text}\"\n", | |
" assert encoded_slow == encoded_original, f\"Slow encode error: {lang} - {text}\"\n", | |
" decoded = original.decode(encoded_original)\n", | |
" decoded_fast = t_fast.decode(encoded_fast, skip_special_tokens=True)\n", | |
" decoded_slow = t_slow.decode(encoded_slow, skip_special_tokens=True)\n", | |
" assert decoded_fast == decoded, f\"Fast decode error: {lang} - {text}\"\n", | |
" assert decoded_slow == decoded, f\"Slow decode error: {lang} - {text}\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "f3123ffd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2490/2490 [00:30<00:00, 80.45it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"for p in tqdm(xnli[\"premise\"]):\n", | |
" for lang, text in p.items():\n", | |
" verify(lang, text)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment