Skip to content

Instantly share code, notes, and snippets.

@petewarden
Last active January 2, 2025 22:26
Show Gist options
  • Save petewarden/09a17d2ded03d24e445c7e7681517ee9 to your computer and use it in GitHub Desktop.
Script to calculate the word error rate of Moonshine models
import argparse
from typing import List, Optional, Tuple

import numpy as np
from datasets import load_dataset as load_ds
from jiwer import wer
from tqdm import tqdm
from whisper.normalizers import EnglishTextNormalizer

from model import MoonshineOnnxModel
from transcribe import load_tokenizer
def calculate_wer(model_name: str, models_dir: Optional[str] = None) -> float:
    """Calculate Word Error Rate for the given model using Librispeech ASR dataset.

    Args:
        model_name: Moonshine model identifier, e.g. "moonshine/tiny".
        models_dir: Folder containing local model files, or None to let the
            model loader fetch them.

    Returns:
        The word error rate as a fraction (0.0 is a perfect transcript).
    """
    # Use copy of dataset test split to avoid download of full dataset (30GB).
    dataset = load_ds(
        path="hf-audio/esb-datasets-test-only-sorted",
        name="librispeech",
        split="test.clean",
        trust_remote_code=True,
    )
    model = MoonshineOnnxModel(model_name=model_name, models_dir=models_dir)
    normalizer = EnglishTextNormalizer()
    tokenizer = load_tokenizer()
    expected_texts, predicted_texts = process_dataset(dataset, model, tokenizer)
    # Normalize both sides (casing, punctuation, number words) so the score
    # reflects recognition errors rather than formatting differences.
    return wer(
        normalizer(" ".join(expected_texts)),
        normalizer(" ".join(predicted_texts)),
    )
def process_dataset(
    dataset, model: MoonshineOnnxModel, tokenizer
) -> Tuple[List[str], List[str]]:
    """Process the dataset and return list pair of expected and predicted text.

    Args:
        dataset: Iterable of examples with "audio" (containing "array") and
            "text" fields, e.g. a Hugging Face audio dataset split.
        model: Loaded Moonshine ONNX model used for transcription.
        tokenizer: Tokenizer whose decode_batch turns token ids into text.

    Returns:
        Tuple of (expected_texts, predicted_texts), parallel lists with one
        entry per example.
    """
    expected_texts, predicted_texts = [], []
    # enumerate replaces the original hand-rolled counter used only for the
    # empty-prediction warning below.
    for i, example in enumerate(tqdm(dataset)):
        audio = example["audio"]["array"]
        # Model expects a (1, num_samples) float32 batch.
        audio_input = audio[np.newaxis, :].astype(np.float32)
        tokens = model.generate(audio_input)
        predicted_text = tokenizer.decode_batch(tokens)[0]
        # Leading space keeps adjacent words separated when the caller joins
        # the lists into single strings.
        expected_texts.append(" " + example["text"])
        predicted_texts.append(" " + predicted_text)
        if not predicted_text:
            tqdm.write(f"Model predicted an empty text for example {i}")
    return expected_texts, predicted_texts
def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments."""
    arg_parser = argparse.ArgumentParser(
        prog="wer.py",
        description="Word Error Rate test for Moonshine models with Librispeech ASR",
    )
    arg_parser.add_argument(
        "--model_name",
        help="Model to run the WER test with",
        default="moonshine/tiny",
        choices=["moonshine/base", "moonshine/tiny"],
    )
    arg_parser.add_argument(
        "--models_dir",
        help="Folder containing local model files",
        default=None,
    )
    return arg_parser.parse_args()
if __name__ == "__main__":
    # Entry point: run the WER benchmark for the CLI-selected model and
    # report the score as a percentage.
    cli_args = parse_arguments()
    score = calculate_wer(cli_args.model_name, cli_args.models_dir)
    print(f"\n Model: {cli_args.model_name} {cli_args.models_dir}")
    print(f" WER: {100. * score:.2f}% using OpenAI Whisper EnglishTextNormalizer")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment