Convert the Cohere ASR (CohereAsr) SentencePiece tokenizer to a HuggingFace fast tokenizer
| """ | |
| Convert a CohereAsr SentencePiece tokenizer (.model) to a HuggingFace fast tokenizer. | |
| Downloads the tokenizer files from CohereLabs/cohere-transcribe-03-2026 on HuggingFace Hub. | |
| """ | |
| import json | |
| from pathlib import Path | |
| from huggingface_hub import snapshot_download | |
| from tokenizers import AddedToken, Tokenizer, decoders, normalizers | |
| from tokenizers.models import BPE | |
| from transformers import PreTrainedTokenizerFast | |
| from transformers.convert_slow_tokenizer import SentencePieceExtractor | |
| REPO_ID = "CohereLabs/cohere-transcribe-03-2026" | |
| REVISION = "494db8a1d34a3aeb28e9ecf61bae9e7cdef455b9" | |
| def load_spm_proto(spm_path: str): | |
| from sentencepiece import sentencepiece_model_pb2 as model_pb2 | |
| proto = model_pb2.ModelProto() | |
| with open(spm_path, "rb") as f: | |
| proto.ParseFromString(f.read()) | |
| return proto | |
| def convert_cohere_asr_tokenizer(input_dir: str, output_dir: str): | |
| input_dir = Path(input_dir) | |
| spm_path = input_dir / "tokenizer.model" | |
| special_tokens_path = input_dir / "special_tokens_map.json" | |
| if not spm_path.exists(): | |
| raise FileNotFoundError(f"tokenizer.model not found in {input_dir}") | |
| if not special_tokens_path.exists(): | |
| raise FileNotFoundError(f"special_tokens_map.json not found in {input_dir}") | |
| # 1. Extract vocab and BPE merges from the SentencePiece model | |
| extractor = SentencePieceExtractor(str(spm_path)) | |
| result = extractor.extract(model_type=None) | |
| vocab = result["vocab"] | |
| merges = result["merges"] | |
| # 2. Build the BPE tokenizer | |
| tokenizer = Tokenizer( | |
| BPE( | |
| vocab=vocab, | |
| merges=merges, | |
| unk_token="<unk>", | |
| fuse_unk=True, | |
| byte_fallback=True, | |
| dropout=None, | |
| ) | |
| ) | |
| # 3. Load special tokens | |
| with open(special_tokens_path) as f: | |
| special_tokens_map = json.load(f) | |
| additional_special = set(special_tokens_map.get("additional_special_tokens", [])) | |
| core_special = { | |
| special_tokens_map.get("bos_token", "<|startoftranscript|>"), | |
| special_tokens_map.get("eos_token", "<|endoftext|>"), | |
| special_tokens_map.get("unk_token", "<unk>"), | |
| special_tokens_map.get("pad_token", "<pad>"), | |
| } | |
| all_special = additional_special | core_special | |
| # 4. Add control (type 3) and user-defined (type 4) tokens from the SPM proto, | |
| # marking them as special so skip_special_tokens=True strips them during decode | |
| proto = load_spm_proto(str(spm_path)) | |
| spm_added_tokens = [] | |
| for token_id, piece in enumerate(proto.pieces): | |
| if piece.type in [3, 4]: | |
| is_special = piece.piece in all_special or piece.type == 3 | |
| spm_added_tokens.append((token_id, piece.piece, is_special)) | |
| tokenizer.add_tokens( | |
| [ | |
| AddedToken(token, normalized=False, special=special) | |
| for _, token, special in sorted(spm_added_tokens, key=lambda x: x[0]) | |
| ] | |
| ) | |
| # 5. Normalizer: Prepend ▁ then replace all spaces with ▁ (matches SPM behavior exactly) | |
| precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap | |
| _normalizers = [normalizers.Prepend("▁"), normalizers.Replace(" ", "▁")] | |
| if precompiled_charsmap: | |
| _normalizers.insert(0, normalizers.Precompiled(precompiled_charsmap)) | |
| tokenizer.normalizer = normalizers.Sequence(_normalizers) | |
| # 6. No pre-tokenizer needed (normalizer handles ▁ insertion). | |
| # Decoder: convert ▁ back to space, handle byte fallback, strip leading space from prepend. | |
| tokenizer.pre_tokenizer = None | |
| tokenizer.decoder = decoders.Sequence( | |
| [ | |
| decoders.Replace("▁", " "), | |
| decoders.ByteFallback(), | |
| decoders.Fuse(), | |
| decoders.Strip(content=" ", left=1, right=0), | |
| ] | |
| ) | |
| # 7. Wrap in PreTrainedTokenizerFast and save | |
| bos_token = special_tokens_map.get("bos_token", "<|startoftranscript|>") | |
| eos_token = special_tokens_map.get("eos_token", "<|endoftext|>") | |
| unk_token = special_tokens_map.get("unk_token", "<unk>") | |
| pad_token = special_tokens_map.get("pad_token", "<pad>") | |
| fast_tokenizer = PreTrainedTokenizerFast( | |
| tokenizer_object=tokenizer, | |
| bos_token=bos_token, | |
| eos_token=eos_token, | |
| unk_token=unk_token, | |
| pad_token=pad_token, | |
| additional_special_tokens=list(additional_special), | |
| split_special_tokens=True, | |
| ) | |
| fast_tokenizer.save_pretrained(output_dir) | |
| print(f"Saved fast tokenizer to {output_dir}") | |
| print(f" vocab_size: {fast_tokenizer.vocab_size}") | |
| print(f" bos_token_id: {fast_tokenizer.bos_token_id}") | |
| print(f" eos_token_id: {fast_tokenizer.eos_token_id}") | |
| print(f" pad_token_id: {fast_tokenizer.pad_token_id}") | |
| # 8. Verify roundtrip | |
| test_text = "hello world" | |
| encoded = fast_tokenizer.encode(test_text, add_special_tokens=False) | |
| decoded = fast_tokenizer.decode(encoded) | |
| print(f" encode/decode check: '{test_text}' -> {encoded} -> '{decoded}'") | |
| if __name__ == "__main__": | |
| input_dir = snapshot_download( | |
| repo_id=REPO_ID, | |
| revision=REVISION, | |
| allow_patterns=["tokenizer.model", "special_tokens_map.json"], | |
| ) | |
| output_dir = Path(__file__).resolve().parent.parent / "tokenizer" | |
| convert_cohere_asr_tokenizer(input_dir, str(output_dir)) |
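For reference, here is a minimal usage sketch of the converted output. It is not part of the original gist: it assumes the conversion script above has already been run and that its tokenizer/ output directory is reachable from the current working directory, and the sample string is arbitrary. It loads the fast tokenizer back, shows how the normalizer rewrites spaces to ▁, and round-trips a plain string.

# Minimal usage sketch (assumption: the conversion script above was run and its
# "tokenizer" output directory is available at this relative path).
from transformers import PreTrainedTokenizerFast

tok = PreTrainedTokenizerFast.from_pretrained("tokenizer")

# The normalizer prepends ▁ and replaces spaces with ▁, mirroring SentencePiece.
print(tok.backend_tokenizer.normalizer.normalize_str("hello world"))  # expected: "▁hello▁world"

# Round-trip a plain string. Control tokens were registered as special during
# conversion, so skip_special_tokens=True would strip them on decode.
ids = tok.encode("hello world", add_special_tokens=False)
print(ids, "->", tok.decode(ids, skip_special_tokens=True))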
| """ | |
| Test that the converted CohereAsr fast tokenizer produces identical | |
| encode AND decode results to the original SentencePiece model on real multilingual data. | |
| Downloads the original tokenizer from CohereLabs/cohere-transcribe-03-2026 on HuggingFace Hub. | |
| """ | |
| from pathlib import Path | |
| import sentencepiece | |
| from datasets import load_dataset | |
| from huggingface_hub import snapshot_download | |
| from tqdm import tqdm | |
| from transformers import PreTrainedTokenizerFast | |
| REPO_ID = "CohereLabs/cohere-transcribe-03-2026" | |
| REVISION = "494db8a1d34a3aeb28e9ecf61bae9e7cdef455b9" | |
| def verify(fast_tokenizer, sp_tokenizer, lang, text): | |
| # Encode parity | |
| encoded_original = sp_tokenizer.EncodeAsIds(text) | |
| encoded_fast = fast_tokenizer.encode(text, add_special_tokens=False) | |
| assert encoded_fast == encoded_original, ( | |
| f"Encode mismatch for lang={lang}:\n" | |
| f" text: {text!r}\n" | |
| f" original: {encoded_original}\n" | |
| f" fast: {encoded_fast}" | |
| ) | |
| # Decode parity | |
| decoded_original = sp_tokenizer.Decode(encoded_original) | |
| decoded_fast = fast_tokenizer.decode(encoded_fast, skip_special_tokens=False) | |
| assert decoded_fast == decoded_original, ( | |
| f"Decode mismatch for lang={lang}:\n" | |
| f" text: {text!r}\n" | |
| f" original: {decoded_original!r}\n" | |
| f" fast: {decoded_fast!r}" | |
| ) | |
| def main(spm_path: str, converted_dir: str): | |
| # Load both tokenizers | |
| sp_tokenizer = sentencepiece.SentencePieceProcessor() | |
| sp_tokenizer.Load(spm_path) | |
| fast_tokenizer = PreTrainedTokenizerFast.from_pretrained(converted_dir) | |
| # Basic sanity checks | |
| print("Running basic sanity checks...") | |
| for text in ["hello world", " hello world", "yesterday it was thirty-five degrees"]: | |
| verify(fast_tokenizer, sp_tokenizer, "en", text) | |
| print(" Basic checks passed.") | |
| # Test on XNLI multilingual dataset | |
| print("Loading XNLI dataset...") | |
| xnli = load_dataset("xnli", "all_languages", split="validation") | |
| print("Verifying encode+decode parity on XNLI premises...") | |
| for premise in tqdm(xnli["premise"]): | |
| for lang, text in premise.items(): | |
| verify(fast_tokenizer, sp_tokenizer, lang, text) | |
| print("All checks passed!") | |
| if __name__ == "__main__": | |
| spm_dir = snapshot_download( | |
| repo_id=REPO_ID, | |
| revision=REVISION, | |
| allow_patterns=["tokenizer.model"], | |
| ) | |
| spm_path = str(Path(spm_dir) / "tokenizer.model") | |
| converted_dir = str(Path(__file__).resolve().parent.parent / "tokenizer") | |
| main(spm_path, converted_dir) |
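When the full XNLI sweep is too slow for a quick iteration, the same encode/decode parity comparison can be spot-checked on a handful of ad-hoc strings. The sketch below is not part of the original test script; it assumes the downloaded tokenizer.model and the converted tokenizer/ directory are available locally, and the sample strings are arbitrary.

# Quick parity spot check (a sketch under the assumptions stated above).
import sentencepiece
from transformers import PreTrainedTokenizerFast

sp = sentencepiece.SentencePieceProcessor()
sp.Load("tokenizer.model")  # path to the downloaded SentencePiece model
fast = PreTrainedTokenizerFast.from_pretrained("tokenizer")  # converted output dir

for text in ["bonjour le monde", " leading space", "numbers 12345 and punctuation!"]:
    sp_ids = sp.EncodeAsIds(text)
    fast_ids = fast.encode(text, add_special_tokens=False)
    assert sp_ids == fast_ids, f"encode mismatch on {text!r}"
    assert sp.Decode(sp_ids) == fast.decode(fast_ids, skip_special_tokens=False), (
        f"decode mismatch on {text!r}"
    )
print("Spot check passed.")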