@keveman
Last active November 19, 2024 23:16
moonshine.py
import os
import sys
import argparse
import wave
import time
import numpy as np
import tokenizers
#
# To run:
# $ wget https://github.com/usefulsensors/moonshine/raw/refs/heads/main/moonshine/assets/beckett.wav
# $ wget https://github.com/usefulsensors/moonshine/raw/refs/heads/main/moonshine/assets/tokenizer.json
# $ python moonshine.py --model_name 'moonshine/tiny' --wav_file beckett.wav
#
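# If the four ONNX files (preprocess, encode, uncached_decode, cached_decode)
# have already been exported or downloaded locally, they can be pointed to
# directly instead; the directory name below is only an illustration:
# $ python moonshine.py --models_dir ./onnx/tiny --wav_file beckett.wav
#
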
def _get_onnx_weights(model_name):
    from huggingface_hub import hf_hub_download

    repo = "UsefulSensors/moonshine"
    return (
        hf_hub_download(repo, f"{x}.onnx", subfolder=f"onnx/{model_name}")
        for x in ("preprocess", "encode", "uncached_decode", "cached_decode")
    )

class MoonshineMaxModel(object):
    def __init__(self, models_dir=None, model_name=None):
        from max import engine
        from max.dtype import DType

        if models_dir is None:
            assert (
                model_name is not None
            ), "model_name should be specified if models_dir is not"
            preprocess, encode, uncached_decode, cached_decode = (
                self._load_weights_from_hf_hub(model_name)
            )
        else:
            preprocess, encode, uncached_decode, cached_decode = [
                f"{models_dir}/{x}.onnx"
                for x in ["preprocess", "encode", "uncached_decode", "cached_decode"]
            ]
        self.engine = engine.InferenceSession()
        self.preprocess = self.engine.load(preprocess)
        self.encode = self.engine.load(encode)
        self.uncached_decode = self.engine.load(uncached_decode)
        self.cached_decode = self.engine.load(cached_decode)

    def _load_weights_from_hf_hub(self, model_name):
        model_name = model_name.split("/")[-1]
        return _get_onnx_weights(model_name)
    def generate(self, audio, max_len=None):
        "audio has to be a numpy array of shape [1, num_audio_samples]"
        if max_len is None:
            # at most 6 tokens per second of 16 kHz audio
            max_len = int((audio.shape[-1] / 16_000) * 6)
        preprocessed = self.preprocess.execute(args_0=audio)[
            self.preprocess.output_metadata[0].name
        ]
        seq_len = np.array([preprocessed.shape[-2]], dtype=np.int32)
        context = self.encode.execute(args_0=preprocessed, args_1=seq_len)[
            self.encode.output_metadata[0].name
        ]
        # decoding starts from token id 1 and stops when token id 2 is produced
        inputs = np.array([[1]], dtype=np.int32)
        seq_len = np.array([1], dtype=np.int32)
        tokens = [1]
        # the first step runs the uncached decoder, which also returns the key/value cache
        outputs = self.uncached_decode.execute(
            args_0=inputs, args_1=context, args_2=seq_len
        )
        logits, *cache = [
            outputs[x.name] for x in self.uncached_decode.output_metadata
        ]
        for i in range(max_len):
            next_token = logits.squeeze().argmax()
            tokens.extend([next_token])
            if next_token == 2:
                break
            seq_len[0] += 1
            inputs = np.array([[next_token]], dtype=np.int32)
            # subsequent steps run the cached decoder, feeding the cache back in
            outputs = self.cached_decode.execute(
                **dict(
                    args_0=inputs,
                    args_1=context,
                    args_2=seq_len,
                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
                ),
            )
            logits, *cache = [
                outputs[x.name] for x in self.cached_decode.output_metadata
            ]
        return [tokens]

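# MoonshineMaxModel above runs the same four ONNX files through Modular's MAX
# engine instead of onnxruntime (it needs the `max` Python package installed).
# main() below only exercises the ONNX class; a minimal sketch of the MAX path,
# with `audio` a float32 numpy array of shape [1, num_audio_samples]:
#
#   model = MoonshineMaxModel(model_name="moonshine/tiny")
#   tokens = model.generate(audio)
#
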
class MoonshineOnnxModel(object):
    def __init__(self, models_dir=None, model_name=None):
        import onnxruntime

        if models_dir is None:
            assert (
                model_name is not None
            ), "model_name should be specified if models_dir is not"
            preprocess, encode, uncached_decode, cached_decode = (
                self._load_weights_from_hf_hub(model_name)
            )
        else:
            preprocess, encode, uncached_decode, cached_decode = [
                f"{models_dir}/{x}.onnx"
                for x in ["preprocess", "encode", "uncached_decode", "cached_decode"]
            ]
        self.preprocess = onnxruntime.InferenceSession(preprocess)
        self.encode = onnxruntime.InferenceSession(encode)
        self.uncached_decode = onnxruntime.InferenceSession(uncached_decode)
        self.cached_decode = onnxruntime.InferenceSession(cached_decode)

    def _load_weights_from_hf_hub(self, model_name):
        model_name = model_name.split("/")[-1]
        return _get_onnx_weights(model_name)
    def generate(self, audio, max_len=None):
        "audio has to be a numpy array of shape [1, num_audio_samples]"
        if max_len is None:
            # at most 6 tokens per second of 16 kHz audio
            max_len = int((audio.shape[-1] / 16_000) * 6)
        preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
        seq_len = [preprocessed.shape[-2]]
        context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]
        # decoding starts from token id 1 and stops when token id 2 is produced
        inputs = [[1]]
        seq_len = [1]
        tokens = [1]
        # the first step runs the uncached decoder, which also returns the key/value cache
        logits, *cache = self.uncached_decode.run(
            [], dict(args_0=inputs, args_1=context, args_2=seq_len)
        )
        for i in range(max_len):
            next_token = logits.squeeze().argmax()
            tokens.extend([next_token])
            if next_token == 2:
                break
            seq_len[0] += 1
            inputs = [[next_token]]
            # subsequent steps run the cached decoder, feeding the cache back in
            logits, *cache = self.cached_decode.run(
                [],
                dict(
                    args_0=inputs,
                    args_1=context,
                    args_2=seq_len,
                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
                ),
            )
        return [tokens]

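# A minimal programmatic sketch of the ONNX path, assuming the beckett.wav and
# tokenizer.json fetched by the wget commands at the top are in the working
# directory (main() below does the same thing with extra checks and timing):
#
#   with wave.open("beckett.wav") as f:
#       pcm = np.frombuffer(f.readframes(f.getnframes()), np.int16) / 32768.0
#   audio = pcm.astype(np.float32)[None, ...]
#   model = MoonshineOnnxModel(model_name="moonshine/tiny")
#   tokens = model.generate(audio)
#   print(tokenizers.Tokenizer.from_file("tokenizer.json").decode_batch(tokens))
#
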
def main(models_dir, wav_file, model_name):
    m = (
        MoonshineOnnxModel(models_dir=models_dir)
        if models_dir
        else MoonshineOnnxModel(model_name=model_name)
    )
    with wave.open(wav_file) as f:
        params = f.getparams()
        assert (
            params.nchannels == 1
            and params.framerate == 16_000
            and params.sampwidth == 2
        ), "wave file should have 1 channel, 16 kHz sample rate, and int16 samples"
        audio = f.readframes(params.nframes)
    # convert int16 PCM to float32 in [-1, 1) and add a batch dimension
    audio = np.frombuffer(audio, np.int16) / 32768.0
    audio = audio.astype(np.float32)[None, ...]

    print("warmup...")
    for _ in range(4):
        tokens = m.generate(audio)

    N = 4
    start = time.time_ns()
    for _ in range(N):
        tokens = m.generate(audio)
    end = time.time_ns()
    inference_time = (end - start) / N / 1e6
    print(f"Time per inference = {inference_time:.2f}ms")

    tokenizer = tokenizers.Tokenizer.from_file("tokenizer.json")
    text = tokenizer.decode_batch(tokens)
    print(text)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="onnx_standalone",
        description="Standalone ONNX demo of Moonshine models",
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--models_dir", help="Directory containing ONNX files")
    group.add_argument("--model_name", help="Name of moonshine model")
    parser.add_argument("--wav_file", help="Speech WAV file")
    args = parser.parse_args()
    main(**vars(args))