convert cohere asr tokenizer
"""
Convert a CohereAsr SentencePiece tokenizer (.model) to a HuggingFace fast tokenizer.
Downloads the tokenizer files from CohereLabs/cohere-transcribe-03-2026 on HuggingFace Hub.
"""
import json
from pathlib import Path
from huggingface_hub import snapshot_download
from tokenizers import AddedToken, Tokenizer, decoders, normalizers
from tokenizers.models import BPE
from transformers import PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import SentencePieceExtractor
REPO_ID = "CohereLabs/cohere-transcribe-03-2026"
REVISION = "494db8a1d34a3aeb28e9ecf61bae9e7cdef455b9"
def load_spm_proto(spm_path: str):
from sentencepiece import sentencepiece_model_pb2 as model_pb2
proto = model_pb2.ModelProto()
with open(spm_path, "rb") as f:
proto.ParseFromString(f.read())
return proto
def convert_cohere_asr_tokenizer(input_dir: str, output_dir: str):
    input_dir = Path(input_dir)
    spm_path = input_dir / "tokenizer.model"
    special_tokens_path = input_dir / "special_tokens_map.json"
    if not spm_path.exists():
        raise FileNotFoundError(f"tokenizer.model not found in {input_dir}")
    if not special_tokens_path.exists():
        raise FileNotFoundError(f"special_tokens_map.json not found in {input_dir}")

    # 1. Extract vocab and BPE merges from the SentencePiece model
    extractor = SentencePieceExtractor(str(spm_path))
    result = extractor.extract(model_type=None)
    vocab = result["vocab"]
    merges = result["merges"]

    # 2. Build the BPE tokenizer
    tokenizer = Tokenizer(
        BPE(
            vocab=vocab,
            merges=merges,
            unk_token="<unk>",
            fuse_unk=True,
            byte_fallback=True,
            dropout=None,
        )
    )

    # 3. Load special tokens
    with open(special_tokens_path) as f:
        special_tokens_map = json.load(f)
    additional_special = set(special_tokens_map.get("additional_special_tokens", []))
    core_special = {
        special_tokens_map.get("bos_token", "<|startoftranscript|>"),
        special_tokens_map.get("eos_token", "<|endoftext|>"),
        special_tokens_map.get("unk_token", "<unk>"),
        special_tokens_map.get("pad_token", "<pad>"),
    }
    all_special = additional_special | core_special
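
    # Piece types in the SentencePiece ModelProto: NORMAL=1, UNKNOWN=2, CONTROL=3,
    # USER_DEFINED=4, UNUSED=5, BYTE=6. CONTROL pieces are reserved ids that are never
    # produced from raw text, and USER_DEFINED pieces are always segmented as a single
    # piece, which is why both kinds are registered below as AddedTokens on the fast tokenizer.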
    # 4. Add control (type 3) and user-defined (type 4) tokens from the SPM proto,
    # marking them as special so skip_special_tokens=True strips them during decode
    proto = load_spm_proto(str(spm_path))
    spm_added_tokens = []
    for token_id, piece in enumerate(proto.pieces):
        if piece.type in [3, 4]:
            is_special = piece.piece in all_special or piece.type == 3
            spm_added_tokens.append((token_id, piece.piece, is_special))
    tokenizer.add_tokens(
        [
            AddedToken(token, normalized=False, special=special)
            for _, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
        ]
    )

    # 5. Normalizer: Prepend ▁ then replace all spaces with ▁ (matches SPM behavior exactly)
    precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
    _normalizers = [normalizers.Prepend("▁"), normalizers.Replace(" ", "▁")]
    if precompiled_charsmap:
        _normalizers.insert(0, normalizers.Precompiled(precompiled_charsmap))
    tokenizer.normalizer = normalizers.Sequence(_normalizers)

    # 6. No pre-tokenizer needed (normalizer handles ▁ insertion).
    # Decoder: convert ▁ back to space, handle byte fallback, strip leading space from prepend.
    tokenizer.pre_tokenizer = None
    tokenizer.decoder = decoders.Sequence(
        [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
            decoders.Strip(content=" ", left=1, right=0),
        ]
    )

    # 7. Wrap in PreTrainedTokenizerFast and save
    bos_token = special_tokens_map.get("bos_token", "<|startoftranscript|>")
    eos_token = special_tokens_map.get("eos_token", "<|endoftext|>")
    unk_token = special_tokens_map.get("unk_token", "<unk>")
    pad_token = special_tokens_map.get("pad_token", "<pad>")
    fast_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        bos_token=bos_token,
        eos_token=eos_token,
        unk_token=unk_token,
        pad_token=pad_token,
        additional_special_tokens=list(additional_special),
        split_special_tokens=True,
    )
    fast_tokenizer.save_pretrained(output_dir)
    print(f"Saved fast tokenizer to {output_dir}")
    print(f" vocab_size: {fast_tokenizer.vocab_size}")
    print(f" bos_token_id: {fast_tokenizer.bos_token_id}")
    print(f" eos_token_id: {fast_tokenizer.eos_token_id}")
    print(f" pad_token_id: {fast_tokenizer.pad_token_id}")

    # 8. Verify roundtrip
    test_text = "hello world"
    encoded = fast_tokenizer.encode(test_text, add_special_tokens=False)
    decoded = fast_tokenizer.decode(encoded)
    print(f" encode/decode check: '{test_text}' -> {encoded} -> '{decoded}'")


if __name__ == "__main__":
    input_dir = snapshot_download(
        repo_id=REPO_ID,
        revision=REVISION,
        allow_patterns=["tokenizer.model", "special_tokens_map.json"],
    )
    output_dir = Path(__file__).resolve().parent.parent / "tokenizer"
    convert_cohere_asr_tokenizer(input_dir, str(output_dir))
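
Before running the full parity test below, a quick smoke test of the saved tokenizer is to reload it and round-trip a sentence. This is a minimal sketch, assuming the conversion script above was run unchanged so the files sit in a local "tokenizer" directory (adjust the path to wherever output_dir pointed):

from transformers import PreTrainedTokenizerFast

tok = PreTrainedTokenizerFast.from_pretrained("tokenizer")  # hypothetical local path = output_dir above
ids = tok.encode("hello world", add_special_tokens=False)
print(ids)
print(repr(tok.decode(ids)))                            # should round-trip to 'hello world'
print(repr(tok.decode(ids, skip_special_tokens=True)))  # identical here, since no special tokens were added
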
"""
Test that the converted CohereAsr fast tokenizer produces identical
encode AND decode results to the original SentencePiece model on real multilingual data.
Downloads the original tokenizer from CohereLabs/cohere-transcribe-03-2026 on HuggingFace Hub.
"""
from pathlib import Path
import sentencepiece
from datasets import load_dataset
from huggingface_hub import snapshot_download
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast
REPO_ID = "CohereLabs/cohere-transcribe-03-2026"
REVISION = "494db8a1d34a3aeb28e9ecf61bae9e7cdef455b9"
def verify(fast_tokenizer, sp_tokenizer, lang, text):
# Encode parity
encoded_original = sp_tokenizer.EncodeAsIds(text)
encoded_fast = fast_tokenizer.encode(text, add_special_tokens=False)
assert encoded_fast == encoded_original, (
f"Encode mismatch for lang={lang}:\n"
f" text: {text!r}\n"
f" original: {encoded_original}\n"
f" fast: {encoded_fast}"
)
# Decode parity
decoded_original = sp_tokenizer.Decode(encoded_original)
decoded_fast = fast_tokenizer.decode(encoded_fast, skip_special_tokens=False)
assert decoded_fast == decoded_original, (
f"Decode mismatch for lang={lang}:\n"
f" text: {text!r}\n"
f" original: {decoded_original!r}\n"
f" fast: {decoded_fast!r}"
)
def main(spm_path: str, converted_dir: str):
    # Load both tokenizers
    sp_tokenizer = sentencepiece.SentencePieceProcessor()
    sp_tokenizer.Load(spm_path)
    fast_tokenizer = PreTrainedTokenizerFast.from_pretrained(converted_dir)

    # Basic sanity checks
    print("Running basic sanity checks...")
    for text in ["hello world", " hello world", "yesterday it was thirty-five degrees"]:
        verify(fast_tokenizer, sp_tokenizer, "en", text)
    print(" Basic checks passed.")

    # Test on XNLI multilingual dataset
    print("Loading XNLI dataset...")
    xnli = load_dataset("xnli", "all_languages", split="validation")
    print("Verifying encode+decode parity on XNLI premises...")
    for premise in tqdm(xnli["premise"]):
        for lang, text in premise.items():
            verify(fast_tokenizer, sp_tokenizer, lang, text)
    print("All checks passed!")


if __name__ == "__main__":
    spm_dir = snapshot_download(
        repo_id=REPO_ID,
        revision=REVISION,
        allow_patterns=["tokenizer.model"],
    )
    spm_path = str(Path(spm_dir) / "tokenizer.model")
    converted_dir = str(Path(__file__).resolve().parent.parent / "tokenizer")
    main(spm_path, converted_dir)
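
As a final spot check of the special-token handling set up in step 4 of the conversion script, it can help to decode a sequence that contains the control tokens. A sketch, assuming the converted tokenizer sits in the same "tokenizer" directory and that the bos/eos defaults from the conversion script (<|startoftranscript|>, <|endoftext|>) are the ones actually listed in special_tokens_map.json:

from transformers import PreTrainedTokenizerFast

tok = PreTrainedTokenizerFast.from_pretrained("tokenizer")

# Build a sequence the way an ASR decoding loop would: bos id, text ids, eos id.
ids = [tok.bos_token_id] + tok.encode("hello world", add_special_tokens=False) + [tok.eos_token_id]

print(repr(tok.decode(ids, skip_special_tokens=False)))  # control tokens visible in the output
print(repr(tok.decode(ids, skip_special_tokens=True)))   # control tokens stripped, leaving the plain transcript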