Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created October 6, 2022 13:29
Show Gist options
  • Save pashu123/0ed1cca1187f0a311e1469eed3ac9967 to your computer and use it in GitHub Desktop.
Save pashu123/0ed1cca1187f0a311e1469eed3ac9967 to your computer and use it in GitHub Desktop.
import argparse
from shark.shark_inference import SharkInference
import numpy as np
# Command-line interface. The single option names the MLIR file that is read
# from disk and handed to SharkInference further down in this script.
p = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
p.add_argument(
    "--mlir_loc",
    type=str,
    default=None,
    # The value is opened as a file by load_mlir(), so describe it as a path
    # (the previous help text wrongly called it "the device to use").
    help="path to the MLIR file to load",
)
args = p.parse_args()
def load_mlir(mlir_loc):
    """Read an MLIR module from disk and return its text.

    Args:
        mlir_loc: Path of the MLIR file to read, or ``None``.

    Returns:
        The full contents of the file as a string, or ``None`` when
        ``mlir_loc`` is ``None`` (nothing is read in that case).

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Identity check: `== None` can be fooled by objects overriding __eq__.
    if mlir_loc is None:
        return None

    print(f"Trying to load the model from {mlir_loc}.")
    # A single-argument os.path.join() is a no-op, so open the path directly.
    with open(mlir_loc) as f:
        return f.read()
import time

# Read the MLIR text named on the command line (None when --mlir_loc is omitted).
mlir_module = load_mlir(args.mlir_loc)
func_name = "forward"

# Build the inference module for the Vulkan device with the MHLO dialect
# selected, then compile it before any timing starts.
shark_module = SharkInference(
    mlir_module,
    func_name,
    device="vulkan",
    mlir_dialect="mhlo",
)
shark_module.compile()
print("Module compiled successfully!")

# Three random float32 inputs with shapes (1, 64, 64, 4), (1, 320) and
# (1, 77, 768). The first is uniform on [0, 1); the other two are standard
# normal.
# NOTE(review): these presumably match the signature of the exported
# "forward" function — confirm against the MLIR module.
inputs = [
    sample.astype(np.float32)
    for sample in (
        np.random.rand(1, 64, 64, 4),
        np.random.randn(1, 320),
        np.random.randn(1, 77, 768),
    )
]

# Untimed warm-up runs so the timed runs below do not include first-call cost.
print("Warming up for 2 iterations")
for warmup_no in range(1, 3):
    shark_module.forward(inputs)
    print(f"Warmed for iteration{warmup_no}")

# Ten timed forward passes; each wall-clock duration is reported in ms.
for run_no in range(1, 11):
    tic = time.perf_counter()
    shark_module.forward(inputs)
    toc = time.perf_counter()
    print(f"Elapsed time for iteration{run_no}", (toc - tic) * 10**3, "ms.")

# NOTE: an alternative that was previously sketched here in commented-out form:
# timeit.repeat(lambda: shark_module.forward(inputs), number=1, repeat=5),
# printing each returned time in seconds.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment