Last active
November 19, 2024 23:15
-
-
Save keveman/ea167957fb6364470cb265c5d9aa9da1 to your computer and use it in GitHub Desktop.
moonshine.mojo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from max.engine import ( | |
InputSpec, | |
InferenceSession, | |
Model, | |
TensorMap, | |
NamedTensor, | |
) | |
from pathlib import Path | |
from python import Python, PythonObject | |
from max.tensor import Tensor, TensorSpec, TensorShape | |
import sys | |
from time import perf_counter_ns | |
from python import Python, PythonObject | |
from memory import memcpy | |
from collections import List | |
from max.tensor import Tensor, TensorShape | |
# | |
# To run: | |
# | |
# $ wget https://github.com/usefulsensors/moonshine/raw/refs/heads/main/moonshine/assets/beckett.wav | |
# $ wget https://github.com/usefulsensors/moonshine/raw/refs/heads/main/moonshine/assets/tokenizer.json | |
# $ mojo run moonshine.mojo beckett.wav | |
# | |
@always_inline
fn numpy_data_pointer[
    type: DType
](numpy_array: PythonObject) raises -> UnsafePointer[Scalar[type]]:
    """Returns a raw pointer to the numpy array's underlying data buffer.

    Uses numpy's `__array_interface__` protocol: entry `"data"[0]` is the
    integer CPU address of the first element. The returned pointer is only
    valid while `numpy_array` (and its buffer) stays alive; the caller must
    also ensure `type` matches the array's actual dtype.
    """
    return numpy_array.__array_interface__["data"][0].unsafe_get_as_pointer[
        type
    ]()
@always_inline
fn memcpy_to_numpy[
    type: DType
](array: PythonObject, tensor: Tensor[type]) raises:
    """Copies a Mojo tensor's contents into a preallocated numpy array.

    Assumes `array` was allocated with at least `tensor.num_elements()`
    elements of the matching dtype (see `tensor_to_numpy`) — TODO confirm
    callers always guarantee this; no bounds check is performed here.
    """
    var dst = numpy_data_pointer[type](array)
    var src = tensor._ptr
    var length = tensor.num_elements()
    memcpy(dst.address, src.address, length)
@always_inline
fn memcpy_from_numpy[
    type: DType
](array: PythonObject, tensor: Tensor[type]) raises:
    """Copies a numpy array's contents into a preallocated Mojo tensor.

    Mirror of `memcpy_to_numpy`. Assumes `tensor` was sized from the
    array's shape with a matching dtype (see `numpy_to_tensor`); no
    bounds check is performed here.
    """
    var src = numpy_data_pointer[type](array)
    var dst = tensor._ptr
    var length = tensor.num_elements()
    memcpy(dst.address, src.address, length)
@always_inline
fn shape_to_python_list(shape: TensorShape) raises -> PythonObject:
    """Converts a TensorShape into a Python `list` of dimension sizes.

    Used to build the `shape` argument for `np.zeros` in `tensor_to_numpy`.
    """
    # Python.list() constructs an empty Python list directly — no need to
    # round-trip through Python.evaluate("list()").
    var python_list = Python.list()
    for i in range(shape.rank()):
        _ = python_list.append(shape[i])
    return python_list^
@always_inline
fn get_np_dtype[type: DType](np: PythonObject) raises -> PythonObject:
    """Maps a Mojo DType to the corresponding numpy dtype object.

    Only the dtypes this program actually moves across the boundary are
    supported; any other `type` raises "Unknown datatype".
    """
    @parameter
    if type is DType.float32:
        return np.float32
    elif type is DType.int32:
        return np.int32
    elif type is DType.int64:
        return np.int64
    elif type is DType.uint8:
        return np.uint8
    raise "Unknown datatype"
@always_inline
fn tensor_to_numpy[
    type: DType
](tensor: Tensor[type], np: PythonObject) raises -> PythonObject:
    """Materializes a Mojo tensor as a new numpy array (deep copy).

    Allocates a zeroed numpy array of the same shape/dtype, then bulk-copies
    the tensor's bytes into it.
    """
    var shape = shape_to_python_list(tensor.shape())
    var tensor_as_numpy = np.zeros(shape, get_np_dtype[type](np))
    # Explicitly end `shape`'s lifetime before the copy; it is no longer needed.
    _ = shape^
    memcpy_to_numpy(tensor_as_numpy, tensor)
    return tensor_as_numpy^
@always_inline
fn numpy_to_tensor[type: DType](array: PythonObject) raises -> Tensor[type]:
    """Materializes a numpy array as a new Mojo tensor (deep copy).

    Reads the shape from the numpy side, allocates a matching tensor, and
    bulk-copies the bytes. Assumes `array` is C-contiguous with dtype
    matching `type` — TODO confirm at call sites.
    """
    var shape = List[Int]()
    var array_shape = array.shape
    for dim in array_shape:
        # Each `dim` is a PythonObject; convert explicitly to a Mojo Int.
        shape.append(int(dim))
    var out = Tensor[type](shape)
    memcpy_from_numpy(array, out)
    return out^
fn hf_download(
    hf_hub: PythonObject, filename: String, model_name: String
) raises -> Path:
    """Downloads one ONNX file for the given Moonshine variant from HF Hub.

    Args:
        hf_hub: The imported `huggingface_hub` Python module.
        filename: File to fetch, e.g. "encode.onnx".
        model_name: Variant subfolder, e.g. "tiny" or "base".

    Returns:
        Local filesystem path of the cached download.
    """
    repo = "UsefulSensors/moonshine"
    subfolder = String("onnx/") + model_name
    return Path(
        str(hf_hub.hf_hub_download(repo, filename, subfolder=subfolder))
    )
struct Moonshine:
    """Moonshine speech-to-text model split across four ONNX graphs.

    Pipeline: preprocess (raw audio -> features) -> encode (features ->
    encoder context) -> uncached_decode (first token; also produces the
    initial KV cache) -> cached_decode (each further token, feeding the
    previous step's cache tensors back in).
    """

    var session: InferenceSession
    var preprocess: Model
    var encode: Model
    var uncached_decode: Model
    var cached_decode: Model

    fn __init__(inout self, model_name: String) raises:
        """Downloads and loads the four ONNX graphs for `model_name`.

        Args:
            model_name: Moonshine variant, e.g. "tiny".
        """
        hf_hub = Python.import_module("huggingface_hub")
        self.session = InferenceSession()
        self.preprocess = self.session.load(
            hf_download(hf_hub, "preprocess.onnx", model_name)
        )
        self.encode = self.session.load(
            hf_download(hf_hub, "encode.onnx", model_name)
        )
        self.uncached_decode = self.session.load(
            hf_download(hf_hub, "uncached_decode.onnx", model_name)
        )
        self.cached_decode = self.session.load(
            hf_download(hf_hub, "cached_decode.onnx", model_name)
        )

    fn generate(inout self, audio: Tensor[DType.float32]) raises -> List[Int]:
        """Greedily decodes token ids for a (1, num_samples) float32 waveform.

        Returns the generated token ids, including the terminating
        end-of-sequence token (id 2) when it is reached within budget.
        """
        # Token budget: roughly 6 tokens per second of 16 kHz audio.
        # int() guards against Int `/` producing a float, which `range`
        # would reject.
        var max_len = int((audio.shape()[1] / 16000) * 6)
        var preprocessed_output_name: String = self.preprocess.get_model_output_names()[
            0
        ]
        var encoded_output_name: String = self.encode.get_model_output_names()[
            0
        ]
        var preprocessed: Tensor[DType.float32] = self.preprocess.execute(
            "args_0", audio
        ).get[DType.float32](preprocessed_output_name)
        # seq_len is a 1-element int32 tensor the decoders use for position.
        var seq_len: Tensor[DType.int32] = (1)
        seq_len[0] = preprocessed.shape()[1]
        var context: Tensor[DType.float32] = self.encode.execute(
            "args_0", preprocessed, "args_1", seq_len
        ).get[DType.float32](encoded_output_name)
        # Decoding starts at position 1 with the BOS token (id 1).
        seq_len[0] = 1
        var inputs: Tensor[DType.int32] = (1, 1)
        inputs[0][0] = 1
        # First step has no KV cache yet -> uncached_decode.
        var outputs = self.uncached_decode.execute(
            NamedTensor("args_0", inputs),
            NamedTensor("args_1", context),
            NamedTensor("args_2", seq_len),
        )
        var output_names: List[
            String
        ] = self.uncached_decode.get_model_output_names()
        var tokens = List[Int]()
        for i in range(max_len):
            # output_names[0] is the logits; argmax picks the greedy token.
            logits = outputs.get[DType.float32](output_names[0])
            token = int(logits.argmax()[0][0])
            tokens.append(token)
            if token == 2:  # end-of-sequence
                break
            inputs[0][0] = token
            seq_len[0] = seq_len[0] + 1
            # output_names[1..24] are the KV-cache tensors from the previous
            # step; cached_decode consumes them as args_3..args_26.
            outputs = self.cached_decode.execute(
                NamedTensor("args_0", inputs),
                NamedTensor("args_1", context),
                NamedTensor("args_2", seq_len),
                NamedTensor(
                    "args_3", outputs.get[DType.float32](output_names[1])
                ),
                NamedTensor(
                    "args_4", outputs.get[DType.float32](output_names[2])
                ),
                NamedTensor(
                    "args_5", outputs.get[DType.float32](output_names[3])
                ),
                NamedTensor(
                    "args_6", outputs.get[DType.float32](output_names[4])
                ),
                NamedTensor(
                    "args_7", outputs.get[DType.float32](output_names[5])
                ),
                NamedTensor(
                    "args_8", outputs.get[DType.float32](output_names[6])
                ),
                NamedTensor(
                    "args_9", outputs.get[DType.float32](output_names[7])
                ),
                NamedTensor(
                    "args_10", outputs.get[DType.float32](output_names[8])
                ),
                NamedTensor(
                    "args_11", outputs.get[DType.float32](output_names[9])
                ),
                NamedTensor(
                    "args_12", outputs.get[DType.float32](output_names[10])
                ),
                NamedTensor(
                    "args_13", outputs.get[DType.float32](output_names[11])
                ),
                NamedTensor(
                    "args_14", outputs.get[DType.float32](output_names[12])
                ),
                NamedTensor(
                    "args_15", outputs.get[DType.float32](output_names[13])
                ),
                NamedTensor(
                    "args_16", outputs.get[DType.float32](output_names[14])
                ),
                NamedTensor(
                    "args_17", outputs.get[DType.float32](output_names[15])
                ),
                NamedTensor(
                    "args_18", outputs.get[DType.float32](output_names[16])
                ),
                NamedTensor(
                    "args_19", outputs.get[DType.float32](output_names[17])
                ),
                NamedTensor(
                    "args_20", outputs.get[DType.float32](output_names[18])
                ),
                NamedTensor(
                    "args_21", outputs.get[DType.float32](output_names[19])
                ),
                NamedTensor(
                    "args_22", outputs.get[DType.float32](output_names[20])
                ),
                NamedTensor(
                    "args_23", outputs.get[DType.float32](output_names[21])
                ),
                NamedTensor(
                    "args_24", outputs.get[DType.float32](output_names[22])
                ),
                NamedTensor(
                    "args_25", outputs.get[DType.float32](output_names[23])
                ),
                NamedTensor(
                    "args_26", outputs.get[DType.float32](output_names[24])
                ),
            )
            # Switch to cached_decode's output naming for the next iteration.
            output_names = self.cached_decode.get_model_output_names()
        return tokens
fn main() raises:
    """Transcribes the WAV file named on the command line, prints the
    transcript, then reports the average per-transcription latency.

    Expects a 16 kHz 16-bit PCM mono file (presumably — the sample rate is
    hard-coded in Moonshine.generate; verify for other inputs).
    """
    argv = sys.argv()
    # Fail with a clear usage message instead of an opaque index error.
    if len(argv) < 2:
        raise "usage: mojo run moonshine.mojo <audio.wav>"
    m = Moonshine("tiny")
    np = Python.import_module("numpy")
    wave = Python.import_module("wave")
    tokenizers = Python.import_module("tokenizers")
    f = wave.open(argv[1])
    params = f.getparams()
    # Decode 16-bit PCM into float32 in [-1, 1), shape (1, num_samples).
    var audio: PythonObject = np.expand_dims(
        np.frombuffer(f.readframes(params.nframes), np.int16) / 32768.0, 0
    ).astype(np.float32)
    f.close()
    a = numpy_to_tensor[DType.float32](audio)
    tokenizer = tokenizers.Tokenizer.from_file("tokenizer.json")
    tokens = m.generate(a)
    tokens_tensor = Tensor[DType.int32](len(tokens))
    for i in range(len(tokens)):
        tokens_tensor[i] = tokens[i]
    decoded = tokenizer.decode(tensor_to_numpy(tokens_tensor, np))
    print(decoded)
    print("warmup..")
    for _ in range(4):
        _ = m.generate(a)
    N = 4
    print("timing..")
    start_time = perf_counter_ns()
    for _ in range(N):
        _ = m.generate(a)
    end_time = perf_counter_ns()
    # Nanoseconds -> milliseconds, averaged over N runs.
    full_duration = (end_time - start_time) / 1e6 / N
    print(full_duration, "ms")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment