Code to reproduce an ONNX Runtime (ORT) error for cross-attention with dynamic past key / value axes
import os

import numpy as np
from onnx import TensorProto, helper, save
from onnxruntime import InferenceSession

hidden_size = 10
num_heads = 2
head_size = 5
# Create node
attention_input_names = [
    "query",
    "key",
    "q_weight",
    "kv_weight",
    "qkv_bias",
    "key_mask",
    "past_key",
    "past_value",
    "is_cross",
    "use_past",
    "use_layer",
    "has_mask",
]
attention_output_names = ["hidden_state", "present_key", "present_value"]
attention_node = helper.make_node(
    "DecoderAttention",
    inputs=attention_input_names,
    outputs=attention_output_names,
    name="cross_attention_node",
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
# Create initializers: projection weights, bias, and scalar boolean flags
tensors = []
q_weight = np.random.randn(hidden_size, hidden_size)
kv_weight = np.random.randn(hidden_size, hidden_size * 2)
qkv_bias = np.random.randn(hidden_size * 3)
for x, name in zip(
    [q_weight, kv_weight, qkv_bias], ["q_weight", "kv_weight", "qkv_bias"]
):
    tensor_proto = helper.make_tensor(
        name=name, data_type=TensorProto.FLOAT, dims=x.shape, vals=x.flatten().tolist()
    )
    tensors.append(tensor_proto)
# Boolean flag inputs, passed as scalar initializers set to True
for name in [
    "is_cross",
    "use_past",
    "use_layer",
    "has_mask",
]:
    tensor_proto = helper.make_tensor(
        name=name, data_type=TensorProto.BOOL, dims=(), vals=[True]
    )
    tensors.append(tensor_proto)
# Create inputs, with the past key / value batch size and sequence length
# either dynamic (symbolic axes) or fixed
past_kv_shape_1 = ["batch_size", num_heads, "seq_length", head_size]
past_kv_shape_2 = [1, num_heads, 1, head_size]
common_inputs = [
    helper.make_tensor_value_info(
        "query", TensorProto.FLOAT, shape=["seq_length", "batch_size", hidden_size]
    ),
    helper.make_tensor_value_info(
        "key", TensorProto.FLOAT, shape=["seq_length", "batch_size", hidden_size]
    ),
    helper.make_tensor_value_info(
        "key_mask", TensorProto.BOOL, shape=["batch_size", "seq_length"]
    ),
]
inputs_1 = common_inputs + [
    helper.make_tensor_value_info(
        "past_key",
        TensorProto.FLOAT,
        shape=past_kv_shape_1,
    ),
    helper.make_tensor_value_info(
        "past_value",
        TensorProto.FLOAT,
        shape=past_kv_shape_1,
    ),
]
inputs_2 = common_inputs + [
    helper.make_tensor_value_info(
        "past_key",
        TensorProto.FLOAT,
        shape=past_kv_shape_2,
    ),
    helper.make_tensor_value_info(
        "past_value",
        TensorProto.FLOAT,
        shape=past_kv_shape_2,
    ),
]
# Create outputs
outputs = [
    helper.make_tensor_value_info(
        "hidden_state",
        TensorProto.FLOAT,
        shape=["seq_length", "batch_size", hidden_size],
    ),
    helper.make_tensor_value_info(
        "present_key",
        TensorProto.FLOAT,
        shape=["batch_size", num_heads, "seq_length", head_size],
    ),
    helper.make_tensor_value_info(
        "present_value",
        TensorProto.FLOAT,
        shape=["batch_size", num_heads, "seq_length", head_size],
    ),
]
# Create and save both models
for inputs, name in zip(
    [inputs_1, inputs_2], ["test_dynamic_inputs", "test_fixed_inputs"]
):
    graph = helper.make_graph(
        [attention_node],
        initializer=tensors,
        name=name,
        inputs=inputs,
        outputs=outputs,
    )
    model = helper.make_model(graph)
    save(model, f"{name}.onnx")
# Test both models by creating an InferenceSession on the CUDA execution provider
for name in ["test_dynamic_inputs", "test_fixed_inputs"]:
    try:
        InferenceSession(f"{name}.onnx", providers=["CUDAExecutionProvider"])
        print(f"{name}: success")
    except Exception as e:
        print(f"{name}: failed")
        print(f"message: {e}")
    os.remove(f"{name}.onnx")
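The repro only creates sessions with CUDAExecutionProvider, so it assumes a CUDA-enabled onnxruntime-gpu build. A minimal sketch of a pre-check one might run first (the provider-listing call is standard onnxruntime; the exit message is only illustrative):

# Minimal sketch: verify a CUDA-capable ONNX Runtime build is installed before
# running the repro, since the script above only uses CUDAExecutionProvider.
import sys

import onnxruntime as ort

available = ort.get_available_providers()
if "CUDAExecutionProvider" not in available:
    # Illustrative message only; without CUDA the repro cannot exercise the error.
    sys.exit(f"CUDAExecutionProvider not available, found: {available}")
print(f"onnxruntime {ort.__version__} with providers: {available}")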