@pashu123
Created July 25, 2022 15:22
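# Compile the v_diffusion model with SHARK on CPU and check its output
# against the golden (reference) output that ships with the download.
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
import numpy as np

# Fetch the model as MLIR, plus its entry function name, sample inputs,
# and the expected output.
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")

# Build an inference module for the CPU backend from linalg-dialect MLIR.
shark_module = SharkInference(
    mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()

# Run the compiled module on the sample inputs.
result = shark_module.forward(inputs)
print("The obtained result via shark is: ", result)
print("The golden result is:", golden_out)

# Reshape the first golden output to (2, 3, 256, 256) before comparing.
tuple_gold = golden_out[0].reshape(2, 3, 256, 256)
# assert_allclose(actual, desired, ...) scales rtol by |desired|,
# so the golden output is passed as `desired`.
np.testing.assert_allclose(result, tuple_gold, rtol=1e-02, atol=1e-03)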
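The final check passes only if every element satisfies |actual - desired| <= atol + rtol * |desired|, which is the criterion NumPy documents for `np.testing.assert_allclose`. The sketch below is a NumPy-only illustration of that bound with the gist's tolerances (rtol=1e-02, atol=1e-03). It uses random stand-in arrays of the same (2, 3, 256, 256) shape rather than real SHARK output, and the helper `max_tolerance_violation` is hypothetical, not part of SHARK or NumPy.

```python
import numpy as np


def max_tolerance_violation(actual, desired, rtol=1e-2, atol=1e-3):
    """Hypothetical helper: largest amount by which |actual - desired|
    exceeds atol + rtol * |desired|. A value <= 0 means every element
    is within tolerance under the assert_allclose criterion."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    bound = atol + rtol * np.abs(desired)
    return float(np.max(np.abs(actual - desired) - bound))


rng = np.random.default_rng(0)
# Stand-in "golden" data with the shape the gist reshapes to (not real model output).
golden = rng.standard_normal((2, 3, 256, 256)).astype(np.float32)

close = golden + 5e-4   # every element differs by ~5e-4, below atol alone
far = golden.copy()
far[0, 0, 0, 0] += 1.0  # a single element far outside the bound

print(max_tolerance_violation(close, golden) <= 0)  # True
print(max_tolerance_violation(far, golden) <= 0)    # False

# NumPy's own check agrees: the first call passes, the second raises.
np.testing.assert_allclose(close, golden, rtol=1e-02, atol=1e-03)
try:
    np.testing.assert_allclose(far, golden, rtol=1e-02, atol=1e-03)
except AssertionError:
    print("assert_allclose rejected the perturbed array")
```

Because the bound depends on |desired|, the order of the first two arguments matters: the reference values belong in the `desired` slot so that the relative tolerance is measured against them.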