Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save eustlb/7a9aa6139d11e0103c6b65bac103da52 to your computer and use it in GitHub Desktop.

Select an option

Save eustlb/7a9aa6139d11e0103c6b65bac103da52 to your computer and use it in GitHub Desktop.
Reproducer for the Kyutai STT Transformers integration; verifies the `test_generation` test.
# ------ install moshi ------
# git clone https://github.com/kyutai-labs/moshi.git
# cd moshi && git checkout 0395bd6c9a95e899c397a68c75f300f3b5409b2c
# uv pip install -e .
# ----------------------------
import torch
from moshi import run_inference

# Inference configuration, mirroring the CLI arguments of moshi's
# run_inference entry point. 'infile'/'outfile' are kept for parity with
# the CLI even though this script feeds audio from a dataset instead.
cli_args = {
'tokenizer': None,
'moshi_weight': None,
'mimi_weight': None,
'hf_repo': 'kyutai/stt-2.6b-en',
'batch_size': 1,
'device': 'cuda',
'dtype': torch.float32,
'config': None,
'cfg_coef': 1.0,
'infile': 'bria.mp3',
'outfile': ''
}

# Fix all RNG seeds so the generated tokens are reproducible.
run_inference.seed_all(4242)

# Resolve checkpoint/tokenizer files from the Hugging Face Hub.
ckpt = run_inference.loaders.CheckpointInfo.from_hf_repo(
cli_args['hf_repo'],
cli_args['moshi_weight'],
cli_args['mimi_weight'],
cli_args['tokenizer'],
cli_args['config']
)
ckpt.lm_config['context'] = 129 # match transformers default

from datasets import load_dataset, Audio

# Dummy LibriSpeech split, resampled to the 24 kHz rate the codec expects.
dataset = load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))

# Only process the first sample (sorted by id for determinism).
first_audio = dataset.sort("id").select(range(1))[:1]["audio"]
waveform = first_audio[0]["array"]

# Instantiate the audio codec, text tokenizer, and language model.
mimi = ckpt.get_mimi(device=cli_args['device'])
text_tokenizer = ckpt.get_text_tokenizer()
lm = ckpt.get_moshi(device=cli_args['device'], dtype=cli_args['dtype'])

state = run_inference.InferenceState(
ckpt,
mimi,
text_tokenizer,
lm,
cli_args['batch_size'],
cli_args['cfg_coef'],
cli_args['device'],
use_sampling=False,
**ckpt.lm_gen_config
)

# Shape the audio as (batch, channels, samples) and replicate across the batch.
pcm = torch.from_numpy(waveform).to(device=cli_args['device'], dtype=torch.float32)
pcm = pcm[None, None, :]
pcm = pcm.expand(cli_args['batch_size'], -1, -1)

# Run inference and persist the tokens as the expected reference output.
out_items = state.run(pcm)
print("Sample 0 output shape:", out_items.shape)
torch.save(out_items, "expected_tokens_generate.pt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment