Created
February 10, 2026 14:49
-
-
Save eustlb/980bade49311336509985f9a308e80af to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # SPDX-License-Identifier: Apache-2.0 | |
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | |
| """ | |
| Reproduce expected outputs for each VoxtralRealtime HF integration test. | |
| Uses vLLM offline inference (as in run_eval.py) to generate reference | |
| transcriptions for every @slow integration test in | |
| test_modeling_voxtral_realtime.py, then saves them to a JSON file. | |
| """ | |
import json
import os

# NOTE: must run before `from vllm import ...` below — CUDA device visibility
# is read when the CUDA runtime initializes, so setting it here pins this
# script to GPU index 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from dataclasses import asdict

import datasets
import numpy as np
from datasets import Audio, load_dataset
from mistral_common.audio import Audio as MistralAudio
from mistral_common.protocol.instruct.chunk import RawAudio
from mistral_common.protocol.transcription.request import (
    StreamingMode,
    TranscriptionRequest,
)
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from transformers.audio_utils import load_audio
from vllm import LLM, EngineArgs, SamplingParams
| MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602" | |
| SAMPLING_RATE = 16_000 | |
| ENGINE_CONFIG = dict( | |
| model=MODEL_NAME, | |
| max_model_len=8192, | |
| max_num_seqs=8, | |
| limit_mm_per_prompt={"audio": 1}, | |
| config_format="mistral", | |
| load_format="mistral", | |
| tokenizer_mode="mistral", | |
| enforce_eager=True, | |
| gpu_memory_utilization=0.9, | |
| ) | |
| # Audio URLs used in the integration tests | |
| AUDIO_URL_DUDE = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/dude_where_is_my_car.wav" | |
| AUDIO_URL_OBAMA = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3" | |
| OUTPUT_FILE = "expected_outputs.json" | |
| # ── helpers ────────────────────────────────────────────────────────────────── | |
def _build_engine() -> LLM:
    """Construct the offline vLLM engine described by ENGINE_CONFIG."""
    # Round-trip through EngineArgs so the config is validated before LLM().
    return LLM(**asdict(EngineArgs(**ENGINE_CONFIG)))
def _tokenize_audio(tokenizer: MistralTokenizer, audio_array: np.ndarray):
    """Tokenize a single audio sample via mistral_common (mirrors run_eval.py).

    Returns a ``(prompt_tokens, processed_audio_array)`` pair: the prompt
    token ids for the transcription request and the audio array as encoded
    by the tokenizer.
    """
    clip = MistralAudio(
        audio_array=audio_array,
        sampling_rate=SAMPLING_RATE,
        format="wav",
    )
    request = TranscriptionRequest(
        audio=RawAudio.from_audio(clip),
        streaming=StreamingMode.OFFLINE,
        language=None,
    )
    encoded = tokenizer.encode_transcription(request)
    return encoded.tokens, encoded.audios[0].audio_array
def _generate_batch(
    llm: LLM,
    tokenizer: MistralTokenizer,
    audio_arrays: list[np.ndarray],
) -> list[str]:
    """Run vLLM offline generation on a list of audio arrays and return decoded strings."""
    audio_config = tokenizer.instruct_tokenizer.tokenizer.audio

    prompts = []
    per_prompt_params = []
    for array in audio_arrays:
        prompt_tokens, processed_audio = _tokenize_audio(tokenizer, array)
        # Token budget: capacity implied by the audio length, minus the prompt
        # tokens already consumed (the trailing -1 presumably reserves one slot
        # for an end-of-sequence token — confirm against run_eval.py).
        num_samples = processed_audio.shape[0]
        budget = audio_config.num_audio_tokens(num_samples) - len(prompt_tokens) - 1

        prompts.append(
            {
                "prompt_token_ids": prompt_tokens,
                "multi_modal_data": {"audio": [(processed_audio, None)]},
            }
        )
        per_prompt_params.append(SamplingParams(temperature=0.0, max_tokens=budget))

    completions = llm.generate(prompts, sampling_params=per_prompt_params)
    return [tokenizer.decode(c.outputs[0].token_ids) for c in completions]
| # ── test-case reproducers ──────────────────────────────────────────────────── | |
def reproduce_test_single_longform(llm, tokenizer) -> list[str]:
    """Mirrors VoxtralRealtimeForConditionalGenerationIntegrationTest.test_single_longform"""
    rule = "=" * 60
    print(f"\n{rule}")
    print("Reproducing: test_single_longform")
    print(rule)

    audio = load_audio(AUDIO_URL_DUDE, SAMPLING_RATE)
    transcripts = _generate_batch(llm, tokenizer, [audio])
    for idx, text in enumerate(transcripts):
        print(f" [{idx}]: {text}")
    return transcripts
def reproduce_test_batched(llm, tokenizer) -> list[str]:
    """Mirrors VoxtralRealtimeForConditionalGenerationIntegrationTest.test_batched"""
    rule = "=" * 60
    print(f"\n{rule}")
    print("Reproducing: test_batched")
    print(rule)

    dataset = datasets.load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )
    # Deterministic order: sort by sample id, then take the first five clips.
    samples = dataset.sort("id")[:5]["audio"]
    arrays = [sample["array"] for sample in samples]

    transcripts = _generate_batch(llm, tokenizer, arrays)
    for idx, text in enumerate(transcripts):
        print(f" [{idx}]: {text}")
    return transcripts
def reproduce_test_batched_longform(llm, tokenizer) -> list[str]:
    """Mirrors VoxtralRealtimeForConditionalGenerationIntegrationTest.test_batched_longform"""
    rule = "=" * 60
    print(f"\n{rule}")
    print("Reproducing: test_batched_longform")
    print(rule)

    clips = [
        load_audio(AUDIO_URL_DUDE, SAMPLING_RATE),
        load_audio(AUDIO_URL_OBAMA, SAMPLING_RATE),
    ]
    transcripts = _generate_batch(llm, tokenizer, clips)
    for idx, text in enumerate(transcripts):
        print(f" [{idx}]: {text}")
    return transcripts
| # ── main ───────────────────────────────────────────────────────────────────── | |
def main():
    """Generate reference transcriptions for every @slow integration test
    and persist them to OUTPUT_FILE as JSON.

    Builds the tokenizer and vLLM engine once, runs each reproducer, then
    writes the collected outputs and prints a truncated summary.
    """
    tokenizer = MistralTokenizer.from_hf_hub(MODEL_NAME)
    llm = _build_engine()

    results = {}
    results["test_single_longform"] = reproduce_test_single_longform(llm, tokenizer)
    results["test_batched"] = reproduce_test_batched(llm, tokenizer)
    results["test_batched_longform"] = reproduce_test_batched_longform(llm, tokenizer)

    # FIX: force UTF-8 — the platform default (e.g. cp1252 on Windows) can
    # raise UnicodeEncodeError on non-ASCII transcription text.
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    # ── summary ────────────────────────────────────────────────────────
    print(f"\n{'=' * 60}")
    print(f"All expected outputs saved to {OUTPUT_FILE}")
    print(f"{'=' * 60}")
    for test_name, outputs in results.items():
        print(f"\n {test_name} ({len(outputs)} output(s)):")
        for i, text in enumerate(outputs):
            # Truncate long transcriptions to keep the console summary readable.
            preview = text[:120] + ("..." if len(text) > 120 else "")
            print(f" [{i}]: {preview}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment