Created
September 2, 2025 12:19
-
-
Save eustlb/45fac784f7f7244c4f1f90e239eae00e to your computer and use it in GitHub Desktop.
reproducer_voxtral_mini_wer_librispeech
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- Setup: dataset, processor, and model for the Voxtral WER reproduction ---
import os

import torch
from datasets import Audio, load_dataset
from transformers import VoxtralForConditionalGeneration, VoxtralProcessor
from whisper.normalizers import EnglishTextNormalizer
import jiwer

# Pin the run to a single GPU; fall back to CPU when CUDA is unavailable.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE = 32
MODEL_ID = "mistralai/Voxtral-Mini-3B-2507"

# LibriSpeech test-clean, with audio resampled to 16 kHz for the model.
test_set = load_dataset("hf-audio/esb-datasets-test-only-sorted", "librispeech", split="test.clean")
test_set = test_set.cast_column("audio", Audio(sampling_rate=16000))

processor = VoxtralProcessor.from_pretrained(MODEL_ID)
model = VoxtralForConditionalGeneration.from_pretrained(MODEL_ID, device_map=torch_device, torch_dtype=torch.bfloat16)
def eval_batch(batch):
    """Transcribe one batched slice of the dataset.

    ``batch["audio"]`` holds the decoded audio entries and ``batch["text"]``
    the reference transcripts; returns parallel ``references``/``predictions``
    columns for the mapped dataset.
    """
    arrays = []
    formats = []
    for sample in batch["audio"]:
        arrays.append(sample["array"])
        # NOTE(review): mixes dict-style (["array"]) and attribute-style
        # (.metadata.codec) access on the same element — assumes the decoded
        # audio object exposes `.metadata`; confirm against the installed
        # `datasets` version.
        formats.append(sample.metadata.codec.upper())

    inputs = processor.apply_transcription_request(
        language="en",
        audio=arrays,
        format=formats,
        model_id=MODEL_ID,
    )
    inputs.to(torch_device, dtype=torch.bfloat16)

    generated = model.generate(**inputs, max_new_tokens=10000)
    # Drop the prompt tokens: decode only the newly generated continuation.
    transcriptions = processor.batch_decode(
        generated[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
    )
    return {
        "references": batch["text"],
        "predictions": transcriptions,
    }
# Run transcription over the full test set, replacing the original columns
# with references/predictions, and persist the raw outputs before scoring.
infered_test_set = test_set.map(eval_batch, batched=True, batch_size=BATCH_SIZE, remove_columns=test_set.column_names)
infered_test_set.save_to_disk("infered_test_set")

# Normalize both sides with Whisper's English normalizer so WER is not
# dominated by casing, punctuation, or number-formatting differences.
normalizer = EnglishTextNormalizer()
normalized_refs = [normalizer(ref) for ref in infered_test_set["references"]]
normalized_hyps = [normalizer(hyp) for hyp in infered_test_set["predictions"]]

# Mean of per-utterance WER: every utterance weighted equally ...
sum_wer = sum(jiwer.wer(ref, hyp) for ref, hyp in zip(normalized_refs, normalized_hyps))
print(f"mean WER: {sum_wer / len(infered_test_set)}")
# ... versus corpus-level WER: errors pooled over all words (typo fixed:
# "Courpus" -> "Corpus" in the report label).
print(f"Corpus WER: {jiwer.wer(normalized_refs, normalized_hyps)}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment