
@HDCharles
HDCharles / gist:3deb02e26e3d53f106c15216ff0e608a
Last active July 25, 2023 20:09
benchmarks for mixed_precision triton kernel
(M, N, K)             Abf16Wint8            cublas_bf16_linear    ratio = cublas_bf16_linear / Abf16Wint8 (higher is better)
(5120, 1280, 1280)    0.1783519983291626    0.09504000097513199   0.5328788119308213
(5120, 1280, 3840)    0.41705599427223206   0.2657119929790497    0.6371134730786461
(5120, 1280, 5120)    0.548255980014801     0.39129599928855896   0.7137104081892464
(5120, 5120, 1280)    0.6700800061225891    0.3110080063343048    0.46413563080914677
(5120, 5120, 3840)    1.6803359985351562    0.9334560036659241    0.5555174706009217
(5120, 5120, 5120)    2.511807918548584     1.164560079574585     0.46363420983540454
(16384, 1280, 1280)   0.48159998655319214   0.28620800375938416   0.5942857386848636
(16384, 1280, 3840)   1.2650560140609741    0.7705600261688232    0.6091113892223932
(16384, 1280, 5120)   1.7240959405899048    1.2534880638122559    0.727040783695234
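For context on how numbers like these are typically gathered (the gist's own harness is not shown here), a minimal sketch follows. It times the cuBLAS bf16 linear against a naive dequantize-then-matmul stand-in for the fused Abf16Wint8 kernel, using CUDA events; shapes, scales, and the weight layout are illustrative assumptions.

import torch

def bench_ms(fn, iters=100):
    # Simple CUDA-event timer: warm up, then average over `iters` runs.
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

M, N, K = 5120, 1280, 1280
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w_bf16 = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
w_int8 = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
scale = torch.randn(N, device="cuda", dtype=torch.bfloat16).abs()

t_cublas = bench_ms(lambda: torch.nn.functional.linear(a, w_bf16))
# Naive stand-in for the fused mixed-dtype kernel: dequantize, then bf16 matmul.
t_mixed = bench_ms(lambda: torch.nn.functional.linear(a, w_int8.to(torch.bfloat16) * scale[:, None]))
print(f"cublas_bf16: {t_cublas:.4f}  dequant+bf16: {t_mixed:.4f}  ratio: {t_cublas / t_mixed:.3f}")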
@HDCharles
HDCharles / gist:a67453ca18e2462102dec8a16c83ed1f
Created July 27, 2023 07:22
triton mixed matmul kernels
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
#   provided configs
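The decorator the comment describes is attached like this (a minimal sketch building on the imports above; the block sizes and tuning key are illustrative, not the gist's actual choices):

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=8),
    ],
    key=["M", "N", "K"],  # changing any of these shape arguments triggers re-tuning
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr):
    ...  # kernel body elided in this sketch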
@HDCharles
HDCharles / gist:9f783e7ae3531127e8a2233760b52a65
Last active July 27, 2023 08:19
mini benchmark for triton matmul kernels
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
def get_configs_io_bound():
    configs = []
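    # Hedged reconstruction of the truncated body, modeled on the helper of the
    # same name in upstream triton.ops.matmul: enumerate Config objects over
    # block shapes and pipeline depths for IO-bound problem sizes (the split-k
    # variants the upstream helper also appends are omitted here).
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        Config({"BLOCK_M": block_m, "BLOCK_N": block_n,
                                "BLOCK_K": block_k, "SPLIT_K": 1},
                               num_stages=num_stages, num_warps=num_warps))
    return configs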
@HDCharles
HDCharles / gist:00562275b4e360e2784a058c283d7f73
Created July 27, 2023 19:34
int8_weight_only_linear_kernel
//
// Generated by LLVM NVPTX Back-End
//
.version 8.1
.target sm_80
.address_size 64
// .globl int8_weight_only_linear_kernel_0d1d2d3d4d5d6d7d8d9c10c11d12c13d14c
.extern .shared .align 1 .b8 global_smem[];
@HDCharles
HDCharles / gist:1c2109774d1d83194428338b08711219
Created July 28, 2023 16:33
triton mixed dtype matmul kernels
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
import nvtx
import time
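This gist pulls in `nvtx`, which suggests regions are tagged for Nsight profiling; a minimal sketch of that pattern follows (the region name and shapes are illustrative, and only the cuBLAS path is shown):

a = torch.randn(5120, 1280, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1280, 1280, device="cuda", dtype=torch.bfloat16)

# Tag the region so it shows up as a named range in the profiler timeline.
with nvtx.annotate("cublas_bf16_linear"):
    c = F.linear(a, w)
torch.cuda.synchronize()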
@HDCharles
HDCharles / benchmark.py
Created August 1, 2023 17:33
benchmarking mixed dtype matmuls
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
import nvtx
import time
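The `time` import above suggests a wall-clock harness; a minimal, hedged sketch (the gist's own measurement loop is not shown past its imports), with explicit synchronization so asynchronous CUDA launches are fully counted:

def time_fn(fn, iters=100):
    fn()
    torch.cuda.synchronize()                 # warm-up and flush pending work
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e3   # average ms per call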
@HDCharles
HDCharles / gist:b2d8c916cfc4629d3f81f09de734e577
Created August 14, 2023 16:21
microbenchmarks for mixed dtype kernels
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
import nvtx
import time
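Since this gist shares its imports with the benchmarks above, here is a hedged sketch of a microbenchmark sweep over shapes like those in the table at the top of this page; the kernel under test is a stand-in, and do_bench's exact return shape differs slightly across Triton versions:

shapes = [(5120, 1280, 1280), (5120, 5120, 5120), (16384, 1280, 5120)]
for M, N, K in shapes:
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
    # do_bench reports timing(s) in ms for the given callable.
    ms = triton.testing.do_bench(lambda: F.linear(a, w))
    print((M, N, K), ms)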
@HDCharles
HDCharles / gist:44952fc614a75ad083f5054d50ef5341
Created September 19, 2023 23:56
not using block pointers
@triton.jit
def matmul_kernel_with_block_pointers(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr, s1_ptr, s2_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,
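The stride comment above maps directly onto torch tensor strides; a small illustration of what would be passed for a contiguous bf16 activation of shape (M, K) (the shapes are illustrative):

import torch

a = torch.randn(5120, 1280, dtype=torch.bfloat16)  # (M, K), contiguous
stride_am, stride_ak = a.stride()
# Moving one row down skips K elements; moving one column right skips 1.
assert (stride_am, stride_ak) == (1280, 1)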