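Gist by @pashu123, last active October 13, 2022 17:42. It downloads the `resnet_50_fp16_torch` model from the SHARK tank, compiles it with `SharkInference`, and checks the result against the reference (golden) output.
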
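```python
import numpy as np
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model

# Fetch the FP16 ResNet-50 MLIR module, its entry-point name, sample inputs,
# and the reference ("golden") output from the SHARK tank bucket.
mlir_model, func_name, inputs, golden_out = download_torch_model(
    "resnet_50_fp16_torch", tank_url="gs://shark_tank/prashant_nod"
)

# Wrap the linalg-dialect MLIR in an inference module and compile it.
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
shark_module.compile()

# Run the compiled module and compare it to the golden output, passed as the
# "desired" argument so rtol scales with the reference values. The loose
# tolerances presumably account for reduced FP16 precision.
result = shark_module.forward(inputs)
np.testing.assert_allclose(result, golden_out, rtol=1e-02, atol=1e-03)
```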
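The final check relies on `np.testing.assert_allclose`, which passes when every element satisfies `|actual - desired| <= atol + rtol * |desired|`. The tolerance is asymmetric: it scales with the second (`desired`) argument. The sketch below reproduces that core criterion in a hypothetical helper, `within_tolerance`, and applies it to synthetic data. It assumes that a float32 to float16 round trip is a fair stand-in for FP16 precision loss. A real FP16 model accumulates error across many layers, so its deviation from the golden output is likely larger than what is shown here.

```python
import numpy as np

def within_tolerance(actual, desired, rtol=1e-2, atol=1e-3):
    """Hypothetical helper mirroring the core element-wise test of
    np.testing.assert_allclose: |actual - desired| <= atol + rtol * |desired|."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    return bool(np.all(np.abs(actual - desired) <= atol + rtol * np.abs(desired)))

# Synthetic stand-in for model outputs (made-up data, not ResNet-50 logits).
rng = np.random.default_rng(0)
reference = rng.standard_normal(1000).astype(np.float32)

# Round-trip through float16 to mimic the rounding error of FP16 storage.
fp16_result = reference.astype(np.float16).astype(np.float32)

# The gist's tolerances absorb float16 rounding (relative error <= 2**-11).
print(within_tolerance(fp16_result, reference))                       # → True
# Near-float32 tolerances do not.
print(within_tolerance(fp16_result, reference, rtol=1e-7, atol=0.0))  # → False

# NumPy's own check agrees with the helper under the gist's tolerances.
np.testing.assert_allclose(fp16_result, reference, rtol=1e-02, atol=1e-03)
```

In this sketch `atol` only matters for values near zero, where `rtol * |desired|` alone would be too strict. For typical logit magnitudes, `rtol` dominates.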