# Gist by @tiandiao123, last active July 4, 2023.
# Benchmarks a Triton self-attention implementation against DeepSpeed's
# Triton attention op for output correctness and inference latency.
import torch
import numpy as np

from deepspeed.ops.transformer.inference.triton.attention import compute_attention as deepspeed_compute_attention
from inference.ops.self_attention import self_attention_compute_using_triton
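
# The benchmark feeds both ops a fused QKV tensor of shape
# (batch, seq_len, 3 * hidden_size) -- here (4, 24, 3 * 256), i.e. 8 heads of
# head_size 32. How each kernel splits it internally is an assumption; this
# illustrative helper (not used by the benchmark) sketches the usual chunking:
def split_fused_qkv(qkv: torch.Tensor):
    # (batch, seq_len, 3 * hidden) -> three (batch, seq_len, hidden) tensors
    q, k, v = torch.chunk(qkv, 3, dim=-1)
    return q, k, v
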
def run_func(func, qkv):
    """Invoke an attention op with the fixed argument set used for benchmarking."""
    return func(qkv,
                alibi=None,
                head_size=32,
                scale=1.2,
                input_mask=None,
                layer_past=None,
                triangular=True)

def benchmark_inference(func, qkv):
    """Return the mean GPU latency (ms) of `func` over repeated runs, timed with CUDA events."""
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    repetitions = 300
    timings = np.zeros((repetitions, 1))

    # GPU warm-up so lazy kernel compilation and cache effects do not skew the timings
    for _ in range(10):
        _ = run_func(func, qkv)

    # Measure performance
    with torch.no_grad():
        for rep in range(repetitions):
            new_qkv = qkv.clone()
            starter.record()
            _ = run_func(func, new_qkv)
            ender.record()
            # Wait for the GPU to finish before reading the elapsed time
            torch.cuda.synchronize()
            timings[rep] = starter.elapsed_time(ender)

    mean_syn = np.sum(timings) / repetitions
    return mean_syn
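
# A mean over 300 repetitions can be skewed by occasional stragglers; an
# illustrative variant (not part of the original gist) that summarizes the
# same timings array with median and tail percentiles:
def summarize_timings(timings: np.ndarray):
    # timings holds per-iteration latencies in milliseconds
    return {
        "mean_ms": float(np.mean(timings)),
        "median_ms": float(np.median(timings)),
        "p99_ms": float(np.percentile(timings, 99)),
    }
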
def test():
    print("working on self-attention layer using triton")
    qkv = torch.randn((4, 24, 256 * 3), device="cuda")

    new_qkv = qkv.clone()
    data_output_triton = self_attention_compute_using_triton(new_qkv,
                                                             alibi=None,
                                                             head_size=32,
                                                             scale=1.2,
                                                             input_mask=None,
                                                             layer_past=None,
                                                             triangular=True,
                                                             use_flash=False)
    latency_1 = benchmark_inference(self_attention_compute_using_triton, qkv.clone())

    print("working on deepspeed inference")
    new_qkv = qkv.clone()
    data_output_deepspeed = deepspeed_compute_attention(new_qkv,
                                                        alibi=None,
                                                        head_size=32,
                                                        scale=1.2,
                                                        input_mask=None,
                                                        layer_past=None,
                                                        triangular=True)
    latency_2 = benchmark_inference(deepspeed_compute_attention, qkv.clone())

    print("checking that the deepspeed op and the triton-implemented op agree")
    print(torch.allclose(data_output_deepspeed.cpu(), data_output_triton.cpu(), rtol=1e-3, atol=1e-3))

    print("testing inference speed of the triton implementation without kernel fusion")
    print("triton implementation: " + str(latency_1) + " ms")
    print("deepspeed implementation: " + str(latency_2) + " ms")


if __name__ == "__main__":
    test()
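
# To run (assumptions, not verified against specific releases): a CUDA-capable
# GPU plus `torch`, `numpy`, `deepspeed`, `triton`, and the local `inference`
# package providing self_attention_compute_using_triton must be importable.
# Then simply:
#   python <this_file>.py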