Created
February 20, 2026 15:18
-
-
Save eustlb/34f79f34d423ccf8983c2c6c8dab2bcc to your computer and use it in GitHub Desktop.
Reproduce expected outputs for test_integration_longform in transformers/tests/models/mimi/test_modeling_mimi.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reproduce expected outputs for test_integration_longform in
# transformers/tests/models/mimi/test_modeling_mimi.py
#
# This uses the original moshi codebase (https://github.com/kyutai-labs/moshi)
# to generate reference values.
#
# Installation:
# git clone https://github.com/kyutai-labs/moshi.git
# uv pip install -e moshi/moshi/
# uv pip install librosa
# transformers must also be importable, since the script uses
# transformers.audio_utils.load_audio to read the test clip.
#
# Usage:
# python reproduce_outputs_test_integration_longform.py
| import torch | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from moshi.models import loaders | |
| from transformers.audio_utils import load_audio | |
def normalize(arr):
    """Return ``arr`` divided by its L2 norm.

    With no axis given, ``np.linalg.norm`` uses the 2-norm of the flattened
    array, so every element is scaled by the same factor and the shape is kept.
    An all-zero input yields NaNs (division by zero), matching the helper in
    the transformers test this script mirrors.
    """
    magnitude = np.linalg.norm(arr)
    return arr / magnitude
def compute_rmse(arr1, arr2):
    """Root-mean-square difference between two arrays after unit-norm scaling.

    Both inputs are first scaled with ``normalize`` so that only their shapes
    (not their overall loudness) are compared. The mean is taken over every
    element of the broadcast difference.
    """
    residual = normalize(arr1) - normalize(arr2)
    return np.sqrt(np.mean(np.square(residual)))
# Run on GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the test clip, requesting a 24000 Hz sampling rate from load_audio.
audio = load_audio(
    "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3",
    sampling_rate=24000,
)
# Prepend batch and channel axes: [T] -> [1, 1, T]
# (assumes load_audio returns a 1-D array -- TODO confirm).
wav = torch.from_numpy(audio).float()[None, None].to(device)
input_len = wav.shape[-1]
# The reference waveform is the same for every codebook setting, so convert
# it to numpy once instead of on every loop iteration.
arr = wav[0].cpu().numpy()

# Load Mimi from the original moshi codebase (weights fetched from the Hub).
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)

# Round-trip the clip through Mimi at each codebook count and print the
# values the transformers integration test compares against.
with torch.no_grad():
    for n_q in (8, 32):
        mimi.set_num_codebooks(n_q)
        codes = mimi.encode(wav)
        # Trim any decoder output beyond the original input length.
        decoded = mimi.decode(codes)[..., :input_len]
        arr_dec = decoded[0].cpu().numpy()
        rmse = compute_rmse(arr, arr_dec)

        print(f"num_codebooks={n_q}:")
        print(f" codes shape: {codes.shape}")
        print(f" codes sum: {codes.sum().item()}")
        print(f" rmse: {rmse}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment