Created
October 6, 2022 13:29
-
-
Save pashu123/0ed1cca1187f0a311e1469eed3ac9967 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from shark.shark_inference import SharkInference | |
import numpy as np | |
# Command-line interface: the only option is the location of the MLIR file
# that the rest of the script compiles and benchmarks.
p = argparse.ArgumentParser(
    # The module has no docstring of its own, so fall back to a short summary
    # instead of an empty description in --help.
    description=__doc__
    or "Compile an MLIR module with SHARK and time its forward pass.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
p.add_argument(
    "--mlir_loc",
    type=str,
    default=None,
    # The previous help text ("the device to use") described the wrong option.
    help="path to the textual MLIR (mhlo) file to compile and benchmark",
)
args = p.parse_args()
def load_mlir(mlir_loc):
    """Read a textual MLIR module from disk.

    Args:
        mlir_loc: Path (``str`` or ``os.PathLike``) to the MLIR file, or None.

    Returns:
        The file contents as a string, or None when ``mlir_loc`` is None.
    """
    if mlir_loc is None:
        return None
    print(f"Trying to load the model from {mlir_loc}.")
    # MLIR's textual format is UTF-8; read it explicitly as such rather than
    # relying on the platform's default locale encoding.
    with open(mlir_loc, encoding="utf-8") as f:
        return f.read()
import statistics
import time

# Benchmark configuration (previously inline magic values).
FUNC_NAME = "forward"
DEVICE = "vulkan"
MLIR_DIALECT = "mhlo"
WARMUP_ITERS = 2
BENCH_ITERS = 10

mlir_module = load_mlir(args.mlir_loc)
if mlir_module is None:
    # Fail fast with a usage message instead of handing None to SharkInference.
    p.error("--mlir_loc is required: pass the path to the MLIR file to benchmark")

func_name = FUNC_NAME
shark_module = SharkInference(
    mlir_module, func_name, device=DEVICE, mlir_dialect=MLIR_DIALECT
)
shark_module.compile()
print("Module compiled successfully!")

# Random float32 inputs. NOTE(review): these shapes are hard-coded here and
# presumably must match the signature of the compiled `forward` — confirm
# against the model the MLIR was exported from.
inputs = [
    x.astype(np.float32)
    for x in (
        np.random.rand(1, 64, 64, 4),
        np.random.randn(1, 320),
        np.random.randn(1, 77, 768),
    )
]

# Untimed warm-up runs so one-off first-call costs don't skew measurements.
print(f"Warming up for {WARMUP_ITERS} iterations")
for i in range(WARMUP_ITERS):
    shark_module.forward(inputs)
    print(f"Warmed for iteration{i+1}")

timings_ms = []
for i in range(BENCH_ITERS):
    start = time.perf_counter()
    shark_module.forward(inputs)
    elapsed_ms = (time.perf_counter() - start) * 10**3
    timings_ms.append(elapsed_ms)
    print(f"Elapsed time for iteration{i+1}", elapsed_ms, "ms.")

# Aggregate view of the run; the median is less sensitive to outliers.
if timings_ms:
    print(
        f"Summary over {len(timings_ms)} iterations: "
        f"mean {statistics.mean(timings_ms):.3f} ms, "
        f"median {statistics.median(timings_ms):.3f} ms, "
        f"min {min(timings_ms):.3f} ms."
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment