Skip to content

Instantly share code, notes, and snippets.

@eustlb
Created September 2, 2025 12:19
Show Gist options
  • Select an option

  • Save eustlb/45fac784f7f7244c4f1f90e239eae00e to your computer and use it in GitHub Desktop.

Select an option

Save eustlb/45fac784f7f7244c4f1f90e239eae00e to your computer and use it in GitHub Desktop.
reproducer_voxtral_mini_wer_librispeech
from datasets import load_dataset, Audio
from transformers import VoxtralForConditionalGeneration, VoxtralProcessor
import os
import torch
from whisper.normalizers import EnglishTextNormalizer
import jiwer
# Pin the run to a single GPU (falls back to CPU if CUDA is unavailable).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32
MODEL_ID = "mistralai/Voxtral-Mini-3B-2507"
# LibriSpeech test-clean from the ESB pre-sorted evaluation mirror;
# resample all audio to the 16 kHz rate the model expects.
test_set = load_dataset("hf-audio/esb-datasets-test-only-sorted", "librispeech", split="test.clean")
test_set = test_set.cast_column("audio", Audio(sampling_rate=16000))
processor = VoxtralProcessor.from_pretrained(MODEL_ID)
# bfloat16 weights placed directly on the chosen device.
model = VoxtralForConditionalGeneration.from_pretrained(MODEL_ID, device_map=torch_device, torch_dtype=torch.bfloat16)
def eval_batch(batch):
    """Transcribe one batch of audio clips and return predictions with references.

    Intended for `datasets.Dataset.map(batched=True)`: reads `batch["audio"]`
    and `batch["text"]`, returns `references`/`predictions` columns.
    """
    waveforms = [el['array'] for el in batch["audio"]]
    # NOTE(review): assumes each audio element exposes `.metadata.codec`
    # (torchcodec-decoded audio) — confirm against the installed datasets version.
    codecs = [el.metadata.codec.upper() for el in batch["audio"]]

    inputs = processor.apply_transcription_request(
        language="en",
        audio=waveforms,
        format=codecs,
        model_id=MODEL_ID,
    )
    # BatchFeature.to mutates in place; only floating-point tensors get the dtype cast.
    inputs.to(torch_device, dtype=torch.bfloat16)

    generated = model.generate(**inputs, max_new_tokens=10000)
    # Strip the prompt tokens so only newly generated text is decoded.
    prompt_len = inputs.input_ids.shape[1]
    transcripts = processor.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)

    return {
        "references": batch["text"],
        "predictions": transcripts,
    }
# Run batched inference over the whole split; drop the original columns so the
# resulting dataset holds only the references/predictions pairs.
infered_test_set = test_set.map(eval_batch, batched=True, batch_size=BATCH_SIZE, remove_columns=test_set.column_names)
infered_test_set.save_to_disk("infered_test_set")

# Normalize both sides with Whisper's English text normalizer before scoring,
# so casing/punctuation differences do not count as errors.
normalizer = EnglishTextNormalizer()
normalized_refs = [normalizer(ref) for ref in infered_test_set["references"]]
normalized_hyps = [normalizer(hyp) for hyp in infered_test_set["predictions"]]

# Mean of per-utterance WERs: every utterance weighted equally regardless of length.
sum_wer = sum(jiwer.wer(ref, hyp) for ref, hyp in zip(normalized_refs, normalized_hyps))
print(f"mean WER: {sum_wer / len(infered_test_set)}")
# Corpus-level WER: total errors over total reference words (length-weighted).
# Fix: output label typo "Courpus" -> "Corpus".
print(f"Corpus WER: {jiwer.wer(normalized_refs, normalized_hyps)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment