Benchmark Moonshine / Whisper for varying batch sizes (FLEURS test set)
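Both scripts below transcribe the FLEURS en_us test split in batches, then report word error rate (WER) and inverse real-time factor (RTFx, seconds of audio transcribed per second of wall-clock time), writing one JSON file per batch size. The dependency list is inferred from the imports rather than stated in the gist: roughly pip install torch transformers datasets evaluate jiwer tqdm soundfile (jiwer backs the evaluate WER metric, soundfile handles audio decoding in datasets, and Moonshine support needs a recent transformers release; exact versions are not pinned here).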
Moonshine benchmark (UsefulSensors/moonshine-tiny):
import torch
import evaluate
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from transformers import MoonshineForConditionalGeneration, AutoProcessor, WhisperProcessor
from datasets import load_dataset, Audio
from tqdm import tqdm
import json

wer_metric = evaluate.load("wer")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32

model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny", attn_implementation="sdpa").to(device).to(torch_dtype)
processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")

dataset = load_dataset("google/fleurs", "en_us", split="test", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

# The Whisper processor is only loaded for its English text normalizer.
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
normalizer = EnglishTextNormalizer(whisper_processor.tokenizer.english_spelling_normalizer)

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]

for batch_size in batch_sizes:
    results = {
        "batch_size": batch_size,
        "labels": [],
        "predictions": [],
        "durations": [],
        "times": [],
        "wer": 0,
        "rtfx": 0,
    }

    for samples in tqdm([dataset[i: i + batch_size] for i in range(0, len(dataset), batch_size)]):
        torch.cuda.synchronize()
        start_event.record()
        inputs = processor(
            [audio["array"] for audio in samples["audio"]],
            return_tensors="pt",
            sampling_rate=processor.feature_extractor.sampling_rate,
            padding=True,
        )
        inputs = inputs.to(device, torch_dtype)
        # To avoid hallucination loops, limit the length of the generated text based on
        # the expected number of tokens per second of audio (at most 6.5 tokens/s).
        token_limit_factor = 6.5 / processor.feature_extractor.sampling_rate
        seq_lens = inputs.attention_mask.sum(dim=-1)
        max_length = int((seq_lens * token_limit_factor).max().item())
        generated_ids = model.generate(**inputs, max_length=max_length)
        end_event.record()
        torch.cuda.synchronize()

        results["times"].append(start_event.elapsed_time(end_event) * 1e-3)
        results["labels"].extend(samples["raw_transcription"])
        results["durations"].extend([audio["array"].shape[0] / 16000 for audio in samples["audio"]])
        results["predictions"].extend([processor.decode(tokens, skip_special_tokens=True) for tokens in generated_ids])

    # Normalize, then drop pairs where either side is empty; filtering predictions
    # and labels independently would misalign the prediction/reference pairs.
    norm_pairs = [
        (normalizer(pred), normalizer(label))
        for pred, label in zip(results["predictions"], results["labels"])
    ]
    norm_pairs = [(pred, label) for pred, label in norm_pairs if len(pred) > 0 and len(label) > 0]
    norm_predictions = [pred for pred, _ in norm_pairs]
    norm_labels = [label for _, label in norm_pairs]

    wer = 100 * wer_metric.compute(predictions=norm_predictions, references=norm_labels)

    audio_length = sum(results["durations"])
    total_time = sum(results["times"])
    rtfx = round(audio_length / total_time, 2)

    results["wer"] = wer
    results["rtfx"] = rtfx

    with open(f"results_moonshine_tiny_{batch_size}.json", "w") as f:
        json.dump(results, f, indent=4)
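The generation cap above is easy to sanity-check by hand. A small sketch of the same arithmetic, with a hypothetical 10-second clip (the sampling rate and the 6.5 tokens/s factor are taken from the script; the clip length is just an example):

    sampling_rate = 16000
    token_limit_factor = 6.5 / sampling_rate        # at most 6.5 tokens per second of audio
    seq_len = 10 * sampling_rate                    # hypothetical 10 s clip -> 160000 samples
    max_length = int(seq_len * token_limit_factor)  # 160000 * 6.5 / 16000 = 65
    print(max_length)                               # 65 tokens allowed for this batch

Since the cap uses the longest sequence in the batch, shorter clips in the same batch share the same max_length.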
Whisper benchmark (openai/whisper-tiny.en):
import torch
import evaluate
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from datasets import load_dataset, Audio
from tqdm import tqdm
import json

wer_metric = evaluate.load("wer")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to(device).to(torch_dtype)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")

dataset = load_dataset("google/fleurs", "en_us", split="test", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]

for batch_size in batch_sizes:
    results = {
        "batch_size": batch_size,
        "labels": [],
        "predictions": [],
        "durations": [],
        "times": [],
        "wer": 0,
        "rtfx": 0,
    }

    for samples in tqdm([dataset[i: i + batch_size] for i in range(0, len(dataset), batch_size)]):
        torch.cuda.synchronize()
        start_event.record()
        inputs = processor(
            [audio["array"] for audio in samples["audio"]],
            return_tensors="pt",
            sampling_rate=processor.feature_extractor.sampling_rate,
        )
        inputs = inputs.to(device, torch_dtype)
        generated_ids = model.generate(**inputs)
        end_event.record()
        torch.cuda.synchronize()

        results["times"].append(start_event.elapsed_time(end_event) * 1e-3)
        results["labels"].extend(samples["raw_transcription"])
        results["durations"].extend([audio["array"].shape[0] / 16000 for audio in samples["audio"]])
        results["predictions"].extend([processor.decode(tokens, skip_special_tokens=True) for tokens in generated_ids])

    # Normalize, then drop pairs where either side is empty; filtering predictions
    # and labels independently would misalign the prediction/reference pairs.
    norm_pairs = [
        (normalizer(pred), normalizer(label))
        for pred, label in zip(results["predictions"], results["labels"])
    ]
    norm_pairs = [(pred, label) for pred, label in norm_pairs if len(pred) > 0 and len(label) > 0]
    norm_predictions = [pred for pred, _ in norm_pairs]
    norm_labels = [label for _, label in norm_pairs]

    wer = 100 * wer_metric.compute(predictions=norm_predictions, references=norm_labels)

    audio_length = sum(results["durations"])
    total_time = sum(results["times"])
    rtfx = round(audio_length / total_time, 2)

    results["wer"] = wer
    results["rtfx"] = rtfx

    with open(f"results_whisper_tiny_en_{batch_size}.json", "w") as f:
        json.dump(results, f, indent=4)
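Once both scripts have run, the per-batch-size JSON files they write can be collected into a single summary. A minimal sketch, assuming all files from the two runs above sit in the working directory (file names and keys match what the scripts emit):

    import json

    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
    patterns = {
        "moonshine-tiny": "results_moonshine_tiny_{}.json",
        "whisper-tiny.en": "results_whisper_tiny_en_{}.json",
    }

    # Print WER and RTFx side by side for every benchmarked batch size.
    for bs in batch_sizes:
        row = []
        for name, pattern in patterns.items():
            with open(pattern.format(bs)) as f:
                r = json.load(f)
            row.append(f"{name}: WER={r['wer']:.2f} RTFx={r['rtfx']:.2f}")
        print(f"batch_size={bs:>3}  " + "  ".join(row))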