# moonshine.py
import argparse
import time
import wave

import numpy as np
import tokenizers
#
# To run:
# $ wget https://github.com/usefulsensors/moonshine/raw/refs/heads/main/moonshine/assets/beckett.wav
# $ wget https://github.com/usefulsensors/moonshine/raw/refs/heads/main/moonshine/assets/tokenizer.json
# $ python moonshine.py --model_name 'moonshine/tiny' --wav_file beckett.wav
#
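# Requires numpy and tokenizers; onnxruntime for MoonshineOnnxModel,
# huggingface_hub when fetching weights via --model_name, and the MAX
# engine for MoonshineMaxModel.
#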
def _get_onnx_weights(model_name):
    from huggingface_hub import hf_hub_download

    repo = "UsefulSensors/moonshine"
    return (
        hf_hub_download(repo, f"{x}.onnx", subfolder=f"onnx/{model_name}")
        for x in ("preprocess", "encode", "uncached_decode", "cached_decode")
    )

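# Moonshine ships as four ONNX graphs: "preprocess" turns raw audio samples
# into features, "encode" builds the audio context, "uncached_decode" runs
# the first decoder step, and "cached_decode" reuses the decoder's key/value
# cache for every step after that.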
class MoonshineMaxModel(object):
    def __init__(self, models_dir=None, model_name=None):
        from max import engine

        if models_dir is None:
            assert (
                model_name is not None
            ), "model_name should be specified if models_dir is not"
            preprocess, encode, uncached_decode, cached_decode = (
                self._load_weights_from_hf_hub(model_name)
            )
        else:
            preprocess, encode, uncached_decode, cached_decode = [
                f"{models_dir}/{x}.onnx"
                for x in ["preprocess", "encode", "uncached_decode", "cached_decode"]
            ]
        self.engine = engine.InferenceSession()
        self.preprocess = self.engine.load(preprocess)
        self.encode = self.engine.load(encode)
        self.uncached_decode = self.engine.load(uncached_decode)
        self.cached_decode = self.engine.load(cached_decode)

    def _load_weights_from_hf_hub(self, model_name):
        model_name = model_name.split("/")[-1]
        return _get_onnx_weights(model_name)
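    # Greedy decode: run the preprocessor and encoder once, then take the
    # argmax token at each decoder step until the end token or max_len.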
    def generate(self, audio, max_len=None):
        """audio must be a numpy array of shape [1, num_audio_samples]."""
        if max_len is None:
            # Moonshine emits at most ~6 tokens per second of 16 kHz audio.
            max_len = int((audio.shape[-1] / 16_000) * 6)
        preprocessed = self.preprocess.execute(args_0=audio)[
            self.preprocess.output_metadata[0].name
        ]
        seq_len = np.array([preprocessed.shape[-2]], dtype=np.int32)
        context = self.encode.execute(args_0=preprocessed, args_1=seq_len)[
            self.encode.output_metadata[0].name
        ]
        # Seed the decoder with token id 1 (start of sequence).
        inputs = np.array([[1]], dtype=np.int32)
        seq_len = np.array([1], dtype=np.int32)
        tokens = [1]
        outputs = self.uncached_decode.execute(
            args_0=inputs, args_1=context, args_2=seq_len
        )
        logits, *cache = [outputs[x.name] for x in self.uncached_decode.output_metadata]
        for _ in range(max_len):
            next_token = logits.squeeze().argmax()
            tokens.append(next_token)
            if next_token == 2:  # token id 2 marks end of sequence
                break
            seq_len[0] += 1
            inputs = np.array([[next_token]], dtype=np.int32)
            outputs = self.cached_decode.execute(
                args_0=inputs,
                args_1=context,
                args_2=seq_len,
                **{f"args_{i+3}": x for i, x in enumerate(cache)},
            )
            logits, *cache = [
                outputs[x.name] for x in self.cached_decode.output_metadata
            ]
        return [tokens]

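# Same model and decode loop as above, but running on onnxruntime instead of
# the MAX engine; onnxruntime takes its feeds as a dict and accepts plain
# Python lists for the integer inputs.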
class MoonshineOnnxModel(object):
    def __init__(self, models_dir=None, model_name=None):
        import onnxruntime

        if models_dir is None:
            assert (
                model_name is not None
            ), "model_name should be specified if models_dir is not"
            preprocess, encode, uncached_decode, cached_decode = (
                self._load_weights_from_hf_hub(model_name)
            )
        else:
            preprocess, encode, uncached_decode, cached_decode = [
                f"{models_dir}/{x}.onnx"
                for x in ["preprocess", "encode", "uncached_decode", "cached_decode"]
            ]
        self.preprocess = onnxruntime.InferenceSession(preprocess)
        self.encode = onnxruntime.InferenceSession(encode)
        self.uncached_decode = onnxruntime.InferenceSession(uncached_decode)
        self.cached_decode = onnxruntime.InferenceSession(cached_decode)

    def _load_weights_from_hf_hub(self, model_name):
        model_name = model_name.split("/")[-1]
        return _get_onnx_weights(model_name)
    def generate(self, audio, max_len=None):
        """audio must be a numpy array of shape [1, num_audio_samples]."""
        if max_len is None:
            # Moonshine emits at most ~6 tokens per second of 16 kHz audio.
            max_len = int((audio.shape[-1] / 16_000) * 6)
        preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
        seq_len = [preprocessed.shape[-2]]
        context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]
        # Seed the decoder with token id 1 (start of sequence).
        inputs = [[1]]
        seq_len = [1]
        tokens = [1]
        logits, *cache = self.uncached_decode.run(
            [], dict(args_0=inputs, args_1=context, args_2=seq_len)
        )
        for _ in range(max_len):
            next_token = logits.squeeze().argmax()
            tokens.append(next_token)
            if next_token == 2:  # token id 2 marks end of sequence
                break
            seq_len[0] += 1
            inputs = [[next_token]]
            logits, *cache = self.cached_decode.run(
                [],
                dict(
                    args_0=inputs,
                    args_1=context,
                    args_2=seq_len,
                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
                ),
            )
        return [tokens]

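# Load the wav file, benchmark the model over a few timed runs, and decode
# the resulting token ids back to text.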
def main(models_dir, wav_file, model_name):
    m = (
        MoonshineOnnxModel(models_dir=models_dir)
        if models_dir
        else MoonshineOnnxModel(model_name=model_name)
    )
    with wave.open(wav_file) as f:
        params = f.getparams()
        assert (
            params.nchannels == 1
            and params.framerate == 16_000
            and params.sampwidth == 2
        ), "wave file should have 1 channel, 16 kHz sample rate, and int16 samples"
        audio = f.readframes(params.nframes)
    # Scale int16 samples to [-1, 1) float32 and add a batch dimension.
    audio = np.frombuffer(audio, np.int16) / 32768.0
    audio = audio.astype(np.float32)[None, ...]

    print("warmup...")
    for _ in range(4):
        tokens = m.generate(audio)

    N = 4
    start = time.time_ns()
    for _ in range(N):
        tokens = m.generate(audio)
    end = time.time_ns()
    inference_time = (end - start) / N / 1e6
    print(f"Time per inference = {inference_time:.2f}ms")

    # Expects the tokenizer.json fetched alongside the wav file (see header).
    tokenizer = tokenizers.Tokenizer.from_file("tokenizer.json")
    text = tokenizer.decode_batch(tokens)
    print(text)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="onnx_standalone",
        description="Standalone ONNX demo of Moonshine models",
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--models_dir", help="Directory containing ONNX files")
    group.add_argument("--model_name", help="Name of moonshine model")
    parser.add_argument("--wav_file", required=True, help="Speech WAV file")
    args = parser.parse_args()
    main(**vars(args))
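# Example with weights already on disk (directory path is illustrative):
# $ python moonshine.py --models_dir ./onnx/tiny --wav_file beckett.wav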