Created
June 21, 2025 07:04
-
-
Save eustlb/7a9aa6139d11e0103c6b65bac103da52 to your computer and use it in GitHub Desktop.
Reproducer for the Kyutai STT → Transformers integration; generates the reference tokens used by the `test_generation` test.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Reproducer for the Kyutai STT Transformers integration (`test_generation`).

Runs the reference `moshi` implementation greedily on one LibriSpeech dummy
sample and saves the generated tokens to `expected_tokens_generate.pt`, to be
compared against the Transformers port.

Requires a CUDA device, network access to the Hugging Face Hub, and the
pinned `moshi` checkout below.
"""
# ------ install moshi ------
# git clone https://github.com/kyutai-labs/moshi.git
# cd moshi && git checkout 0395bd6c9a95e899c397a68c75f300f3b5409b2c
# uv pip install -e .
# ----------------------------
import torch
from datasets import Audio, load_dataset
from moshi import run_inference

# Inference configuration; `None` entries fall back to checkpoint defaults.
args = {
    'tokenizer': None,
    'moshi_weight': None,
    'mimi_weight': None,
    'hf_repo': 'kyutai/stt-2.6b-en',
    'batch_size': 1,
    'device': 'cuda',
    'dtype': torch.float32,
    'config': None,
    'cfg_coef': 1.0,
    'infile': 'bria.mp3',
    'outfile': '',
}

# Fix all RNG seeds so the generated reference tokens are reproducible.
run_inference.seed_all(4242)

checkpoint_info = run_inference.loaders.CheckpointInfo.from_hf_repo(
    args['hf_repo'],
    args['moshi_weight'],
    args['mimi_weight'],
    args['tokenizer'],
    args['config'],
)
checkpoint_info.lm_config['context'] = 129  # match transformers default

# Load one validation utterance, resampled to 24 kHz.
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
# Only process the first sample (sorted by id for determinism).
speech_sample = ds.sort("id").select(range(1))[:1]["audio"]
sample = speech_sample[0]["array"]

# Instantiate the audio codec, text tokenizer, and language model.
mimi = checkpoint_info.get_mimi(device=args['device'])
text_tokenizer = checkpoint_info.get_text_tokenizer()
lm = checkpoint_info.get_moshi(device=args['device'], dtype=args['dtype'])

state = run_inference.InferenceState(
    checkpoint_info,
    mimi,
    text_tokenizer,
    lm,
    args['batch_size'],
    args['cfg_coef'],
    args['device'],
    use_sampling=False,  # greedy decoding for a deterministic reference
    **checkpoint_info.lm_gen_config,
)

# Shape the waveform to (batch, channels, samples) and replicate across batch.
in_pcms = torch.from_numpy(sample).to(device=args['device'], dtype=torch.float32)[None, None, :]
in_pcms = in_pcms.expand(args['batch_size'], -1, -1)

out_items = state.run(in_pcms)
print("Sample 0 output shape:", out_items.shape)
torch.save(out_items, "expected_tokens_generate.pt")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment