@pashu123
Created January 20, 2025 18:44
import onnx
import onnx.helper as helper
# Define the model's input and output
batch_size = 1
seq_len = 512
hidden_size = 4096
kv_seq_len = 4 # Defined for reference; not used below
kv_hidden_size = 16 # Defined for reference; not used below
num_heads = 32 # Example number of attention heads
head_size = hidden_size // num_heads
# Inputs
query = helper.make_tensor_value_info("query", onnx.TensorProto.FLOAT, [batch_size, seq_len, hidden_size])
key = helper.make_tensor_value_info("key", onnx.TensorProto.FLOAT, [batch_size, seq_len, hidden_size])
value = helper.make_tensor_value_info("value", onnx.TensorProto.FLOAT, [batch_size, seq_len, hidden_size])
past_key = helper.make_tensor_value_info("past_key", onnx.TensorProto.FLOAT, [batch_size, num_heads, seq_len, head_size]) # 4D tensor
past_value = helper.make_tensor_value_info("past_value", onnx.TensorProto.FLOAT, [batch_size, num_heads, seq_len, head_size]) # 4D tensor
seqlens_k = helper.make_tensor_value_info("seqlens_k", onnx.TensorProto.INT32, [batch_size])
total_sequence_length = helper.make_tensor_value_info("total_sequence_length", onnx.TensorProto.INT32, [])
# Outputs
output = helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [batch_size, seq_len, hidden_size])
present_key = helper.make_tensor_value_info("present_key", onnx.TensorProto.FLOAT, [batch_size, num_heads, seq_len, head_size])
present_value = helper.make_tensor_value_info("present_value", onnx.TensorProto.FLOAT, [batch_size, num_heads, seq_len, head_size])
# Create the GroupQueryAttention node with updated attributes
group_query_attention_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query", "key", "value", "past_key", "past_value", "seqlens_k", "total_sequence_length",
    ],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    do_rotary=0, # Rotary position embedding disabled (set 1 to enable)
    kv_num_heads=num_heads, # Number of attention heads for key and value
    local_window_size=-1, # Use -1 to indicate no local (sliding-window) attention
    num_heads=num_heads, # Number of attention heads for query
    rotary_interleaved=0, # Disable interleaved rotary (set 1 to enable)
    scale=1.0, # Attention scale (1.0 here; left unset, the op falls back to 1/sqrt(head_size))
    smooth_softmax=0, # Smooth softmax disabled
    softcap=0.0, # No softcap
)
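# Note: with kv_num_heads == num_heads this node behaves like standard multi-head attention.
# For a genuinely grouped-query configuration, kv_num_heads would be a divisor of num_heads
# (e.g. 8), and the "key"/"value" inputs would then carry kv_num_heads * head_size channels
# instead of hidden_size.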
# Create the graph
graph = helper.make_graph(
    nodes=[group_query_attention_node],
    name="GroupQueryAttentionGraph",
    inputs=[
        query, key, value, past_key, past_value, seqlens_k, total_sequence_length,
    ],
    outputs=[output, present_key, present_value],
)
# Create the model
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("com.microsoft", 1)])
model.ir_version = 7 # Specify the IR version
# Save the model to an ONNX file
onnx.save(model, "group_query_attention.onnx")
print("ONNX model with GroupQueryAttention saved as 'group_query_attention.onnx'.")
import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path
import time
def benchmark_inference(session, inputs, num_iterations=100):
    """
    Benchmark the speed and latency of the inference process.

    Args:
        session: The ONNX Runtime session used for inference.
        inputs: A dictionary containing input data for the model.
        num_iterations: The number of iterations to run for benchmarking.

    Returns:
        A dictionary containing speed (in iterations/second) and latency (in milliseconds).
    """
    # Warm-up runs
    for _ in range(5):
        session.run(None, input_feed=inputs)
    # Measure the time for multiple iterations
    start_time = time.time()
    for _ in range(num_iterations):
        session.run(None, input_feed=inputs)
    end_time = time.time()
    # Calculate metrics
    total_time = end_time - start_time
    latency_ms = (total_time / num_iterations) * 1000  # Time per inference in ms
    speed_ips = num_iterations / total_time  # Inferences per second
    return {
        "speed (it/s)": speed_ips,
        "latency (ms)": latency_ms,
    }
# Example usage
session_options = ort.SessionOptions()
session_options.register_custom_ops_library(get_library_path())
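# Note: GroupQueryAttention is a com.microsoft contrib op built into onnxruntime itself;
# the onnxruntime-extensions library registered above supplies additional custom ops and
# should not be strictly required for this particular node.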
# Load the ONNX model
model_path = "group_query_attention.onnx"
session = ort.InferenceSession(model_path, sess_options=session_options, providers=["CPUExecutionProvider"])
# Prepare mock inputs
batch_size = 1
seq_len = 512
hidden_size = 4096
kv_seq_len = 4 # Defined for reference; not used below
kv_hidden_size = 16 # Defined for reference; not used below
num_heads = 32 # Example number of attention heads
head_size = hidden_size // num_heads
# Input tensors
query = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float32)
key = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float32)
value = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float32)
# Adjust past_key and past_value to be 4-dimensional
past_key = np.random.randn(batch_size, num_heads, seq_len, head_size).astype(np.float32) # 4D
past_value = np.random.randn(batch_size, num_heads, seq_len, head_size).astype(np.float32) # 4D
# Other inputs
seqlens_k = np.array([seq_len - 1], dtype=np.int32) # Per-batch key sequence lengths, shape (batch_size,)
total_sequence_length = np.array(seq_len + 1, dtype=np.int32) # Scalar, matching the graph's [] input shape
cos_cache = np.random.randn(seq_len, hidden_size // 2).astype(np.float32) # Cosine cache for rotary (unused: do_rotary=0)
sin_cache = np.random.randn(seq_len, hidden_size // 2).astype(np.float32) # Sine cache for rotary (unused: do_rotary=0)
# Run inference
inputs = {
    "query": query,
    "key": key,
    "value": value,
    "past_key": past_key,
    "past_value": past_value,
    "seqlens_k": seqlens_k,
    "total_sequence_length": total_sequence_length,
}
outputs = session.run(None, input_feed=inputs)
results = benchmark_inference(session, inputs, num_iterations=100)
print("Benchmark Results:")
print(f"Speed: {results['speed (it/s)']:.2f} it/s")
print(f"Latency: {results['latency (ms)']:.2f} ms")
# Display output shapes
print("Output Shape:", outputs[0].shape) # The output tensor (batch_size, seq_len, hidden_size)
print("Present Key Shape:", outputs[1].shape) # Present key (batch_size, seq_len, hidden_size)
print("Present Value Shape:", outputs[2].shape) # Present value (batch_size, seq_len, hidden_size)