import onnx
import onnx.helper as helper

# Model dimensions
batch_size = 1
seq_len = 512
hidden_size = 4096
kv_seq_len = 4        # unused in this example
kv_hidden_size = 16   # unused in this example
num_heads = 32        # example number of attention heads
head_size = hidden_size // num_heads
# Inputs
query = helper.make_tensor_value_info("query", onnx.TensorProto.FLOAT, [batch_size, seq_len, hidden_size])
key = helper.make_tensor_value_info("key", onnx.TensorProto.FLOAT, [batch_size, seq_len, hidden_size])
value = helper.make_tensor_value_info("value", onnx.TensorProto.FLOAT, [batch_size, seq_len, hidden_size])
past_key = helper.make_tensor_value_info("past_key", onnx.TensorProto.FLOAT, [batch_size, num_heads, seq_len, head_size])      # 4D KV-cache tensor
past_value = helper.make_tensor_value_info("past_value", onnx.TensorProto.FLOAT, [batch_size, num_heads, seq_len, head_size])  # 4D KV-cache tensor
seqlens_k = helper.make_tensor_value_info("seqlens_k", onnx.TensorProto.INT32, [batch_size])
total_sequence_length = helper.make_tensor_value_info("total_sequence_length", onnx.TensorProto.INT32, [])  # rank-0 scalar

# Outputs
output = helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [batch_size, seq_len, hidden_size])
present_key = helper.make_tensor_value_info("present_key", onnx.TensorProto.FLOAT, [batch_size, num_heads, seq_len, head_size])
present_value = helper.make_tensor_value_info("present_value", onnx.TensorProto.FLOAT, [batch_size, num_heads, seq_len, head_size])
# Create the GroupQueryAttention node (com.microsoft contrib op)
group_query_attention_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query", "key", "value", "past_key", "past_value", "seqlens_k", "total_sequence_length",
    ],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    do_rotary=0,             # Rotary position embedding disabled (set 1 to enable)
    kv_num_heads=num_heads,  # KV heads; equal to num_heads here, so no KV grouping (standard MHA)
    local_window_size=-1,    # -1 disables local (sliding-window) attention
    num_heads=num_heads,     # Number of attention heads for query
    rotary_interleaved=0,    # Disable interleaved rotary (set 1 to enable)
    scale=1.0,               # Explicit scale; if omitted, ORT defaults to 1/sqrt(head_size)
    smooth_softmax=0,        # Smooth softmax disabled
    softcap=0.0,             # No softcap
)
# Create the graph
graph = helper.make_graph(
    nodes=[group_query_attention_node],
    name="GroupQueryAttentionGraph",
    inputs=[query, key, value, past_key, past_value, seqlens_k, total_sequence_length],
    outputs=[output, present_key, present_value],
)
# Create the model; ONNX Runtime also expects an opset import for the
# default ONNX domain, even though the graph only uses a contrib op
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("com.microsoft", 1), helper.make_opsetid("", 14)],
)
model.ir_version = 7  # Specify the IR version

# Save the model to an ONNX file
onnx.save(model, "group_query_attention.onnx")
print("ONNX model with GroupQueryAttention saved as 'group_query_attention.onnx'.")
import time

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path  # optional here: GroupQueryAttention ships with onnxruntime itself
def benchmark_inference(session, inputs, num_iterations=100):
    """
    Benchmark the speed and latency of the inference process.

    Args:
        session: The ONNX Runtime session for running inference.
        inputs: A dictionary containing input data for the model.
        num_iterations: The number of iterations to run for benchmarking.

    Returns:
        A dictionary containing speed (in iterations/second) and latency (in milliseconds).
    """
    # Warm-up runs
    for _ in range(5):
        session.run(None, input_feed=inputs)

    # Measure the time for multiple iterations (perf_counter is the
    # appropriate monotonic clock for interval timing)
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        session.run(None, input_feed=inputs)
    end_time = time.perf_counter()

    # Calculate metrics
    total_time = end_time - start_time
    latency_ms = (total_time / num_iterations) * 1000  # Time per inference in ms
    speed_ips = num_iterations / total_time            # Inferences per second
    return {
        "speed (it/s)": speed_ips,
        "latency (ms)": latency_ms,
    }
# Example usage
session_options = ort.SessionOptions()
session_options.register_custom_ops_library(get_library_path())

# Load the ONNX model (the keyword is sess_options; a session_options=
# keyword would be silently swallowed by **kwargs and never applied)
model_path = "group_query_attention.onnx"
session = ort.InferenceSession(model_path, sess_options=session_options, providers=["CPUExecutionProvider"])
# Prepare mock inputs
batch_size = 1
seq_len = 512
hidden_size = 4096
kv_seq_len = 4        # unused in this example
kv_hidden_size = 16   # unused in this example
num_heads = 32        # example number of attention heads
head_size = hidden_size // num_heads

# Input tensors
query = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float32)
key = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float32)
value = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float32)

# past_key and past_value are 4-dimensional, matching the graph inputs
past_key = np.random.randn(batch_size, num_heads, seq_len, head_size).astype(np.float32)
past_value = np.random.randn(batch_size, num_heads, seq_len, head_size).astype(np.float32)

# Other inputs
seqlens_k = np.array([seq_len - 1], dtype=np.int32)  # per-batch sequence lengths (GQA expects total length minus 1)
total_sequence_length = np.array(seq_len + 1, dtype=np.int32)  # rank-0 scalar, matching the graph input declared with shape []

# Rotary caches (unused since do_rotary=0); if rotary were enabled, ORT
# expects per-head caches of shape (max_seq_len, head_size // 2)
cos_cache = np.random.randn(seq_len, head_size // 2).astype(np.float32)
sin_cache = np.random.randn(seq_len, head_size // 2).astype(np.float32)
# Run inference
inputs = {
    "query": query,
    "key": key,
    "value": value,
    "past_key": past_key,
    "past_value": past_value,
    "seqlens_k": seqlens_k,
    "total_sequence_length": total_sequence_length,
}
outputs = session.run(None, input_feed=inputs)

results = benchmark_inference(session, inputs, num_iterations=100)
print("Benchmark Results:")
print(f"Speed: {results['speed (it/s)']:.2f} it/s")
print(f"Latency: {results['latency (ms)']:.2f} ms")
# Display output shapes
print("Output Shape:", outputs[0].shape)         # (batch_size, seq_len, hidden_size)
print("Present Key Shape:", outputs[1].shape)    # (batch_size, num_heads, present_seq_len, head_size)
print("Present Value Shape:", outputs[2].shape)  # (batch_size, num_heads, present_seq_len, head_size)