# Micro-benchmark: a Triton self-attention kernel vs. the DeepSpeed Triton
# attention op, with latency measured via CUDA events.
import torch
import numpy as np

from deepspeed.ops.transformer.inference.triton.attention import compute_attention as deepspeed_compute_attention
from inference.ops.self_attention import self_attention_compute_using_triton
def run_func(func, qkv):
    # Run the attention op under a fixed configuration shared by both
    # implementations: causal (triangular) mask, no alibi bias, no KV cache.
    return func(qkv,
                alibi=None,
                head_size=32,
                scale=1.2,
                input_mask=None,
                layer_past=None,
                triangular=True)
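

# Not part of the original gist: a hedged reference sketch of how the packed
# qkv tensor is presumably laid out, assuming the ops split the last dimension
# into equal query/key/value chunks and then into heads of size head_size.
def _split_qkv_reference(qkv, head_size=32):
    batch, seq_len, three_hidden = qkv.shape
    hidden = three_hidden // 3
    q, k, v = qkv.split(hidden, dim=-1)  # each: (batch, seq_len, hidden)
    n_heads = hidden // head_size
    # Reshape to (batch, n_heads, seq_len, head_size) for per-head attention.
    q = q.view(batch, seq_len, n_heads, head_size).transpose(1, 2)
    k = k.view(batch, seq_len, n_heads, head_size).transpose(1, 2)
    v = v.view(batch, seq_len, n_heads, head_size).transpose(1, 2)
    return q, k, v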
def benchmark_inference(func, qkv):
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    repetitions = 300
    timings = np.zeros(repetitions)

    # GPU warm-up: let one-time costs (kernel compilation, allocator growth)
    # happen before any measurement.
    for _ in range(10):
        run_func(func, qkv)

    # Measure performance. CUDA events time work on the GPU, so the host must
    # synchronize before reading elapsed_time().
    with torch.no_grad():
        for rep in range(repetitions):
            new_qkv = qkv.clone()
            starter.record()
            run_func(func, new_qkv)
            ender.record()
            torch.cuda.synchronize()  # wait for the GPU to finish
            timings[rep] = starter.elapsed_time(ender)  # milliseconds

    return np.mean(timings)
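

# A minimal standalone sketch (not in the original gist) that exercises the
# same CUDA-event timing pattern on a plain matmul, as a sanity check of the
# harness itself. The lambda absorbs the attention keyword arguments that
# run_func passes; the matmul and its shape are purely illustrative.
def _timing_sanity_check():
    x = torch.randn((1024, 1024), device="cuda")
    latency = benchmark_inference(lambda qkv, **kwargs: torch.matmul(qkv, qkv), x)
    print(f"matmul baseline: {latency:.3f} ms")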
def test():
    print("working on self-attention layer using triton")
    # qkv packs query, key and value along the last dimension:
    # (batch=4, seq_len=24, 3 * hidden), hidden=256; head_size=32 -> 8 heads.
    qkv = torch.randn((4, 24, 256 * 3), device="cuda")

    new_qkv = qkv.clone()
    data_output_triton = self_attention_compute_using_triton(new_qkv,
                                                             alibi=None,
                                                             head_size=32,
                                                             scale=1.2,
                                                             input_mask=None,
                                                             layer_past=None,
                                                             triangular=True,
                                                             use_flash=False)
    latency_1 = benchmark_inference(self_attention_compute_using_triton, qkv.clone())

    print("working on deepspeed inference")
    new_qkv = qkv.clone()
    data_output_deepspeed = deepspeed_compute_attention(new_qkv,
                                                        alibi=None,
                                                        head_size=32,
                                                        scale=1.2,
                                                        input_mask=None,
                                                        layer_past=None,
                                                        triangular=True)
    latency_2 = benchmark_inference(deepspeed_compute_attention, qkv.clone())

    print("checking correctness of the deepspeed op against the triton-implemented op")
    print(torch.allclose(data_output_deepspeed.cpu(), data_output_triton.cpu(), rtol=1e-3, atol=1e-3))

    print("testing inference speed of the triton implementation without kernel fusion")
    print("triton implementation: " + str(latency_1) + " ms")
    print("deepspeed implementation: " + str(latency_2) + " ms")


if __name__ == "__main__":
    test()
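
# Usage note (assumptions, not verified here): running this script requires a
# CUDA-capable GPU, a DeepSpeed build that ships the Triton inference kernels,
# and the local `inference` package on PYTHONPATH, e.g.:
#
#     python benchmark_triton_attention.py
#
# The file name above is hypothetical; run whatever this file is saved as.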