Created
February 6, 2024 15:15
-
-
Save fxmarty/3810931be5c18a4ea648e71d9e41e082 to your computer and use it in GitHub Desktop.
Repro opt vits
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 1. conda create -n ryzen101 python=3.9
# 2. Install the Ryzen AI Software following https://ryzenai.docs.amd.com/en/latest/manual_installation.html
# 3. Run .\transformers\setup.bat
# 4. The README also recommends running .\transformers\opt-onnx\setup.bat, but it cannot be run: the file does not exist.
# 5. Run .\set_opt_onnx_env.bat opt-125m
# 6. Run .\prepare_model.bat opt-125m
# 7. Then run the code below:
import os

import numpy as np
import onnxruntime

# Input locations. The defaults are the original author's local checkout;
# set these environment variables to run the repro from a different machine
# or directory without editing the script.
model_path = os.environ.get(
    "RYZENAI_OPT_MODEL_PATH",
    r"C:\Users\Felix\HF\RyzenAI-SW\example\transformers\opt-onnx\opt-125m_ortquantized\model_quantized.onnx",
)
vaip_config_path = os.environ.get(
    "RYZENAI_VAIP_CONFIG",
    r"C:\Users\Felix\HF\RyzenAI-SW\example\transformers\opt-onnx\vaip_config.json",
)

# Fail fast with a readable message rather than an error surfacing from
# inside session creation when one of the paths is wrong.
for _required_file in (model_path, vaip_config_path):
    if not os.path.isfile(_required_file):
        raise FileNotFoundError(f"required file not found: {_required_file}")

providers = ["VitisAIExecutionProvider"]
# Cache directory handed to the Vitis AI EP via `cacheDir`; it is placed in
# the same folder as the model file.
cache_dir = os.path.join(os.path.dirname(model_path), "vaie_cache")
# onnxruntime expects one options dict per entry in `providers`, matched by
# position.
provider_options = [{
    "config_file": vaip_config_path,
    "cacheDir": cache_dir,
}]
print("----- LOAD MODEL")
session = onnxruntime.InferenceSession(
    model_path,
    providers=providers,
    provider_options=provider_options,
)
# Requesting an execution provider that is not available may only produce a
# warning, so report which providers the session actually ended up with. If
# "VitisAIExecutionProvider" is absent here, the inference below is not going
# through the Vitis AI EP.
print("providers in use:", session.get_providers())
def build_dummy_inputs(
    batch_size=2,
    seq_len=5,
    num_layers=12,
    num_heads=12,
    head_dim=64,
    token_id_high=5,
    seed=0,
):
    """Build a feed dict for a first decoding step with an empty KV cache.

    The defaults reproduce the shapes this repro uses for opt-125m.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        num_layers: Number of `past_key_values.{i}.key/value` pairs to emit.
        num_heads: Size of axis 1 of each past tensor.
        head_dim: Size of axis 3 of each past tensor.
        token_id_high: Exclusive upper bound for the random token ids.
        seed: Seed for the token-id generator. The default gives a
            reproducible repro; pass None for fresh randomness.

    Returns:
        Dict mapping input names to numpy arrays:
        - "input_ids" and "attention_mask": int64, shape (batch_size, seq_len).
        - Each past tensor: float32, shape (batch_size, num_heads, 0, head_dim).
    """
    rng = np.random.default_rng(seed)
    feeds = {
        "input_ids": rng.integers(
            0, token_id_high, size=(batch_size, seq_len), dtype=np.int64
        ),
        "attention_mask": np.ones((batch_size, seq_len), dtype=np.int64),
    }
    # Past sequence length of 0: nothing is cached yet, so every past tensor
    # is empty along axis 2.
    empty_past_shape = (batch_size, num_heads, 0, head_dim)
    for layer in range(num_layers):
        feeds[f"past_key_values.{layer}.key"] = np.zeros(empty_past_shape, dtype=np.float32)
        feeds[f"past_key_values.{layer}.value"] = np.zeros(empty_past_shape, dtype=np.float32)
    return feeds


inp = build_dummy_inputs()
print("----- RUN INFERENCE")
# Passing None as the output list asks onnxruntime for every model output,
# returned in the same order as session.get_outputs().
outputs = session.run(None, inp)
print("outputs:", len(outputs))
# Report each output's name and shape so a successful run can be checked at a
# glance. Non-tensor outputs have no `.shape`; show their type name instead.
for output_meta, value in zip(session.get_outputs(), outputs):
    print(f"  {output_meta.name}: {getattr(value, 'shape', type(value).__name__)}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment