Skip to content

Instantly share code, notes, and snippets.

@pcuenca
Created May 8, 2024 15:10
Show Gist options
  • Save pcuenca/917353acb55ba3d3347b21e82e91926e to your computer and use it in GitHub Desktop.
Save pcuenca/917353acb55ba3d3347b21e82e91926e to your computer and use it in GitHub Desktop.
Gemma Tokenizer Tests
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "f9427918",
"metadata": {},
"source": [
"## Load the reference `gemma` tokenizer and the transformers tokenizers"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "48a7fddf",
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import hf_hub_download\n",
"\n",
"# Reference SentencePiece model for the original `gemma` tokenizer.\n",
"# NOTE(review): fetched from the 2b base repo, while the transformers tokenizers\n",
"# are loaded from google/codegemma-1.1-7b-it; this assumes every CodeGemma 1.1\n",
"# checkpoint ships the same tokenizer.model - confirm.\n",
"original_path = hf_hub_download(\n",
"    repo_id=\"google/codegemma-1.1-2b\",\n",
"    filename=\"tokenizer.model\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "aae9d9de",
"metadata": {},
"outputs": [],
"source": [
"from gemma.tokenizer import Tokenizer\n",
"\n",
"# Reference implementation: the `gemma` package's Tokenizer built from the\n",
"# downloaded tokenizer.model. Its encode/decode output is treated as ground truth.\n",
"original = Tokenizer(original_path)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "06b063cf",
"metadata": {},
"outputs": [],
"source": [
"from transformers import GemmaTokenizer, AutoTokenizer"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a584d69f",
"metadata": {},
"outputs": [],
"source": [
"# These checks fail against the repo's `main` revision, so load the transformers\n",
"# tokenizer files from the Hub pull-request ref `refs/pr/4` instead.\n",
"HF_REPO_ID = \"google/codegemma-1.1-7b-it\"\n",
"revision = \"refs/pr/4\"\n",
"\n",
"# Load both implementations from the same repo and revision so they are compared like for like.\n",
"t_fast = AutoTokenizer.from_pretrained(HF_REPO_ID, revision=revision)\n",
"t_slow = GemmaTokenizer.from_pretrained(HF_REPO_ID, revision=revision)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "72a1c087",
"metadata": {},
"outputs": [],
"source": [
"# Each special token must encode to the same ids as the reference tokenizer\n",
"# (no BOS/EOS added on either side) and decode back to the exact input string.\n",
"special_tokens = [\n",
"    '<start_of_turn>', '<end_of_turn>', '<mask>',\n",
"    '<|fim_prefix|>', '<|fim_suffix|>', '<|fim_middle|>', '<|file_separator|>',\n",
"]\n",
"for s in special_tokens:\n",
"    encoded = original.encode(s, bos=False, eos=False)\n",
"    encoded_fast = t_fast.encode(s, add_special_tokens=False)\n",
"    assert encoded_fast == encoded, f\"Fast encode failed: {s!r}: {encoded_fast} != {encoded}\"\n",
"    encoded_slow = t_slow.encode(s, add_special_tokens=False)\n",
"    assert encoded_slow == encoded, f\"Slow encode failed: {s!r}: {encoded_slow} != {encoded}\"\n",
"    # Decode the reference ids (already shown equal above) with each tokenizer.\n",
"    decoded_fast = t_fast.decode(encoded)\n",
"    assert decoded_fast == s, f\"Fast decode failed: {s!r} -> {decoded_fast!r}\"\n",
"    decoded_slow = t_slow.decode(encoded)\n",
"    assert decoded_slow == s, f\"Slow decode failed: {s!r} -> {decoded_slow!r}\""
]
},
{
"cell_type": "markdown",
"id": "8ab89d7b",
"metadata": {},
"source": [
"## Compare encoding and decoding on XNLI (validation split, all languages)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0160405a",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a743115c",
"metadata": {},
"outputs": [],
"source": [
"# Multilingual evaluation text: XNLI validation split, \"all_languages\" config,\n",
"# where each premise is a mapping of language code to text.\n",
"xnli = load_dataset(\n",
"    \"xnli\",\n",
"    \"all_languages\",\n",
"    split=\"validation\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9a52691b",
"metadata": {},
"outputs": [],
"source": [
"def verify(lang, text):\n",
"    \"\"\"Assert that both transformers tokenizers agree with the reference on `text`.\n",
"\n",
"    The reference encoding prepends BOS and omits EOS; the transformers calls use\n",
"    their default `encode` settings (presumably adding BOS - confirm). Decoded text\n",
"    is compared with special tokens skipped on the transformers side. `lang` is only\n",
"    used to label the AssertionError raised on the first mismatch.\n",
"    \"\"\"\n",
"    reference_ids = original.encode(text, bos=True, eos=False)\n",
"    fast_ids = t_fast.encode(text)\n",
"    slow_ids = t_slow.encode(text)\n",
"    assert fast_ids == reference_ids, f\"Fast encode error: {lang} - {text}\"\n",
"    assert slow_ids == reference_ids, f\"Slow encode error: {lang} - {text}\"\n",
"\n",
"    reference_text = original.decode(reference_ids)\n",
"    fast_text = t_fast.decode(fast_ids, skip_special_tokens=True)\n",
"    slow_text = t_slow.decode(slow_ids, skip_special_tokens=True)\n",
"    assert fast_text == reference_text, f\"Fast decode error: {lang} - {text}\"\n",
"    assert slow_text == reference_text, f\"Slow decode error: {lang} - {text}\""
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f3123ffd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2490/2490 [00:30<00:00, 80.45it/s]\n"
]
}
],
"source": [
"# Premises are {language code: text} mappings. Hypotheses are assumed to use the\n",
"# datasets TranslationVariableLanguages layout (parallel \"language\" and\n",
"# \"translation\" lists) - confirm with xnli.features[\"hypothesis\"].\n",
"for example in tqdm(xnli):\n",
"    for lang, text in example[\"premise\"].items():\n",
"        verify(lang, text)\n",
"    hypothesis = example[\"hypothesis\"]\n",
"    for lang, text in zip(hypothesis[\"language\"], hypothesis[\"translation\"]):\n",
"        verify(lang, text)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment