Skip to content

Instantly share code, notes, and snippets.

@eustlb
Created February 20, 2026 15:18
Show Gist options
  • Select an option

  • Save eustlb/34f79f34d423ccf8983c2c6c8dab2bcc to your computer and use it in GitHub Desktop.

Select an option

Save eustlb/34f79f34d423ccf8983c2c6c8dab2bcc to your computer and use it in GitHub Desktop.
Reproduce expected outputs for test_integration_longform in transformers/tests/models/mimi/test_modeling_mimi.py
# Reproduce expected outputs for test_integration_longform in
# transformers/tests/models/mimi/test_modeling_mimi.py
#
# This uses the original moshi codebase (https://github.com/kyutai-labs/moshi)
# to generate reference values.
#
# Installation:
# git clone https://github.com/kyutai-labs/moshi.git
# uv pip install -e moshi/moshi/
# uv pip install librosa
# (transformers must also be installed: this script imports transformers.audio_utils)
#
# Usage:
# python reproduce_outputs_test_integration_longform.py
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from moshi.models import loaders
from transformers.audio_utils import load_audio
def normalize(arr):
    """Scale ``arr`` to unit L2 norm.

    Args:
        arr: NumPy array of any shape. With the default arguments,
            ``np.linalg.norm`` computes the 2-norm of the flattened array,
            so a single global scale factor is applied.

    Returns:
        ``arr`` divided by its L2 norm (same shape as the input).

    Note:
        An all-zero input has zero norm; the division then produces NaN
        values (NumPy emits a RuntimeWarning rather than raising).
    """
    return arr / np.linalg.norm(arr)
def compute_rmse(arr1, arr2):
    """Return the root-mean-square error between two unit-normalized arrays.

    Both inputs are first rescaled with ``normalize`` so that the result does
    not depend on the overall gain of either signal.

    Args:
        arr1: First NumPy array.
        arr2: Second NumPy array; must be broadcast-compatible with ``arr1``.

    Returns:
        The square root of the mean squared element-wise difference between
        the two normalized arrays (a NumPy floating-point scalar).
    """
    diff = normalize(arr1) - normalize(arr2)
    return np.sqrt((diff ** 2).mean())
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load audio
# NOTE(review): 24000 is presumably the sample rate Mimi expects -- confirm.
audio = load_audio(
    "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3",
    sampling_rate=24000,
)
wav = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, T]
input_len = wav.shape[-1]
# The reference waveform does not change between iterations, so convert it
# to NumPy once instead of on every pass through the loop below.
arr = wav[0].cpu().numpy()

# Load Mimi from original codebase
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)

# Inference only: no gradients are needed for an encode/decode round trip.
with torch.no_grad():
    for n_q in [8, 32]:
        mimi.set_num_codebooks(n_q)
        codes = mimi.encode(wav)
        decoded = mimi.decode(codes)
        # Keep at most input_len samples so the decoded signal can be compared
        # element-wise with the input (presumably the decoder output can be
        # longer than the input -- confirm).
        decoded = decoded[..., :input_len]
        arr_dec = decoded[0].cpu().numpy()
        rmse = compute_rmse(arr, arr_dec)
        print(f"num_codebooks={n_q}:")
        print(f" codes shape: {codes.shape}")
        print(f" codes sum: {codes.sum().item()}")
        print(f" rmse: {rmse}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment