triton mixed dtype matmul kernels
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
import nvtx
import time
def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': 8},
                               num_stages=num_stages, num_warps=num_warps))
    return configs
config_list = [
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    # good for int8
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
] + get_configs_io_bound()
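# In total: 18 hand-picked configs above plus 5 * 2 * 2 * 4 = 80 io-bound configs
# from get_configs_io_bound(), so each autotuned kernel below benchmarks 98
# candidate configurations the first time it sees a new (M, N, K) key.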
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#     meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#     provided configs
@triton.autotune(
    configs=config_list,
    # configs=[
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    #     triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    # ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, bias_ptr, scale_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_bias,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing C = (A @ B) * scale + bias.
    A has shape (M, K), B has shape (K, N), bias has shape (N,), scale is a scalar,
    and C has shape (M, N). B is cast to bfloat16 before the dot product.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance these pointers as we move in the K direction and accumulate.
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetic` section of the Triton matmul tutorial for details.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    bias_ptrs = bias_ptr + (offs_bn * stride_bias)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted to bfloat16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generating a mask by checking the K dimension.
        # Out-of-bounds elements are set to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension, casting B to bfloat16 for the dot.
        accumulator += tl.dot(a, b.to(tl.bfloat16))
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    scale = tl.load(scale_ptr)
    bias = tl.load(bias_ptrs)
    c = accumulator.to(tl.bfloat16) * scale + bias
    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
def matmul(a, b, bias, scale):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    # assert a.is_contiguous(), "Matrix A must be contiguous"
    # assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    matmul_kernel[grid](
        a, b, bias, scale, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        bias.stride(0),
        c.stride(0), c.stride(1),
    )
    return c
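# Illustrative sanity check (not part of the original gist; the helper name is
# hypothetical). Because the kernel reads the scale with a single tl.load(scale_ptr),
# it behaves as a per-tensor scale, so the wrapper should track a plain bfloat16
# reference of (A @ B) * scale + bias.
def _check_matmul_fp16(D=256):
    a = torch.randn(D, D, device='cuda', dtype=torch.bfloat16)
    b = torch.randn(D, D, device='cuda', dtype=torch.float16)
    bias = torch.randn(D, device='cuda', dtype=torch.bfloat16)
    scale = torch.full((1,), 0.5, device='cuda', dtype=torch.bfloat16)
    c = matmul(a, b, bias, scale)
    # Reference mirrors the kernel: cast B to bfloat16, matmul, apply scale, add bias.
    ref = (a @ b.to(torch.bfloat16)) * scale + bias
    print("fp16-weight matmul max abs diff:", (c - ref).abs().max().item())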
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#     meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#     provided configs
@triton.autotune(
    configs=config_list,
    # configs=[
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    #     triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    #     triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
    # ]
    key=['M', 'N', 'K'],
)
@triton.jit
def int8_weight_only_linear_kernel(
    # Pointers to matrices
    x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
    # by to get the element one row down (X has M rows).
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_b,
    stride_ym, stride_yn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing Y = (X @ W) * s + b, where W is stored as int8 and
    cast to bfloat16 inside the K loop. X has shape (M, K), W has shape (K, N),
    b has shape (N,), s is a scalar scale, and Y has shape (M, N).
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of Y it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of X and W.
    # We will advance these pointers as we move in the K direction and accumulate.
    # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetic` section of the Triton matmul tutorial for details.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_wn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_wn[None, :] * stride_wn)
    b_ptrs = b_ptr + (offs_wn * stride_b)
    step_w = BLOCK_SIZE_K * stride_wk
    step_x = BLOCK_SIZE_K * stride_xk
    # -----------------------------------------------------------
    # Iterate to compute a block of the Y matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted to bfloat16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of X and W, generating a mask by checking the K dimension.
        # Out-of-bounds elements are set to 0.
        x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension, casting the int8 weights to bfloat16.
        accumulator += tl.dot(x, w.to(tl.bfloat16))
        # Advance the ptrs to the next K block.
        x_ptrs += step_x
        w_ptrs += step_w
    s = tl.load(s_ptr)
    b = tl.load(b_ptrs)
    y = (accumulator.to(tl.bfloat16) * s + b)
    # y = accumulator
    # -----------------------------------------------------------
    # Write back the block of the output matrix Y with masks.
    offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
    y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
    tl.store(y_ptrs, y, mask=y_mask)
def int8_weight_only_linear(x, w, b, s):
    # Check constraints.
    assert x.shape[1] == w.shape[0], "Incompatible dimensions"
    # assert x.is_contiguous(), "Matrix x must be contiguous"
    # assert w.is_contiguous(), "Matrix w must be contiguous"
    M, K = x.shape
    K, N = w.shape
    assert b.shape[0] == N
    # Allocates output.
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    int8_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        b.stride(0),
        y.stride(0), y.stride(1),
    )
    return y
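# Illustrative sanity check (not part of the original gist; the helper name is
# hypothetical): the int8 weights are cast to bfloat16 inside the kernel and the
# scalar scale is applied afterwards, so the result should track a plain
# (X @ W.to(bfloat16)) * scale + bias reference.
def _check_int8_linear(D=256):
    x = torch.randn(D, D, device='cuda', dtype=torch.bfloat16)
    w = torch.randint(-128, 127, (D, D), dtype=torch.int8, device='cuda')
    bias = torch.randn(D, device='cuda', dtype=torch.bfloat16)
    scale = torch.full((1,), 0.01, device='cuda', dtype=torch.bfloat16)
    y = int8_weight_only_linear(x, w, bias, scale)
    ref = (x @ w.to(torch.bfloat16)) * scale + bias
    print("int8 weight-only linear max abs diff:", (y - ref).abs().max().item())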
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#     meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#     provided configs
@triton.autotune(
    configs=config_list,
    # configs=[
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    #     triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    #     triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    #     triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
    # ],
    key=['M', 'N', 'K'],
)
@triton.jit
def uint4x2_weight_only_linear_kernel(
    # Pointers to matrices
    x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
    # Matrix dimensions
    M, N, K,  # x is (M, K) and w is (K//2, N); each int8 of w packs two 4-bit values along K
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
    # by to get the element one row down (X has M rows).
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_b,
    stride_ym, stride_yn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing Y = (X @ W) * s + b, where W packs two 4-bit values
    per int8 along the K dimension. X has shape (M, K), W has shape (K//2, N),
    b has shape (N,), s is a scalar scale, and Y has shape (M, N).
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of Y it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of X and W.
    # We will advance these pointers as we move in the K direction and accumulate.
    # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetic` section of the Triton matmul tutorial for details.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_wn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    # Two consecutive K indices map to the same packed byte, hence offs_k // 2.
    w_ptrs = w_ptr + (offs_k[:, None] // 2 * stride_wk + offs_wn[None, :] * stride_wn)
    # Even k reads the low nibble (shift 0, offset by 8); odd k reads the high nibble (shift 4, no offset).
    w_shifts = (offs_k % 2) * 4
    w_subs = (1 - (offs_k % 2)) * 8
    b_ptrs = b_ptr + (offs_wn * stride_b)
    # W only advances half a block per K iteration because it is packed two values per byte.
    step_w = BLOCK_SIZE_K // 2 * stride_wk
    step_x = BLOCK_SIZE_K * stride_xk
    # -----------------------------------------------------------
    # Iterate to compute a block of the Y matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted to bfloat16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of X and W, generating a mask by checking the K dimension.
        # Out-of-bounds elements are set to 0.
        x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # Unpack the two 4-bit values from each loaded byte.
        w = ((w >> w_shifts[:, None]) & 0xF) - w_subs[:, None]
        # We accumulate along the K dimension, casting the unpacked weights to bfloat16.
        accumulator += tl.dot(x, w.to(tl.bfloat16))
        # Advance the ptrs to the next K block.
        x_ptrs += step_x
        w_ptrs += step_w
    s = tl.load(s_ptr)
    b = tl.load(b_ptrs)
    y = (accumulator.to(tl.bfloat16) * s) + b
    # -----------------------------------------------------------
    # Write back the block of the output matrix Y with masks.
    offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
    y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
    tl.store(y_ptrs, y, mask=y_mask)
def uint4x2_weight_only_linear(x, w, b, s):
    # Check constraints.
    assert x.shape[1] == w.shape[0] * 2, "Incompatible dimensions"
    # assert x.is_contiguous(), "Matrix x must be contiguous"
    # assert w.is_contiguous(), "Matrix w must be contiguous"
    M, K = x.shape
    _, N = w.shape
    assert b.shape[0] == N
    # Allocates output.
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    uint4x2_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        b.stride(0),
        y.stride(0), y.stride(1),
    )
    return y
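# Illustrative packing sketch (not part of the original gist; the helper name is
# hypothetical). The kernel above recovers ((byte >> 0) & 0xF) - 8 for even k and
# ((byte >> 4) & 0xF) for odd k, so a packing consistent with that unpacking stores
# w4[0::2] + 8 in the low nibble and w4[1::2] in the high nibble of each byte,
# where w4 is the logical (K, N) 4-bit weight matrix.
def _pack_uint4x2(w_even, w_odd):
    """Pack two (K//2, N) integer tensors into one (K//2, N) int8 tensor.
    w_even holds values in [-8, 7] (even K rows); w_odd holds values in [0, 15] (odd K rows)."""
    w_even = w_even.to(torch.int32)
    w_odd = w_odd.to(torch.int32)
    packed = ((w_even + 8) & 0xF) | ((w_odd & 0xF) << 4)
    # Reinterpret the 0..255 byte pattern as int8 to match the kernel's int8 weight loads.
    return packed.to(torch.uint8).view(torch.int8)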
quantiles = [0.5, 0.2, 0.8]  # quantiles for triton.testing.do_bench (median, 20th, 80th percentile); only used by the commented-out do_bench calls below
torch.cuda.cudart().cudaProfilerStart()
result = {}
for D in [2**12]:
    print(D)
    result[D]={}
    result[D]["cublas linear"]={}
    result[D]["fp16 matmul"]={}
    result[D]["stock triton int8 matmul"]={}
    result[D]["int8 linear"]={}
    result[D]["uint4x2 linear"]={}
    for t_x in [0,1]:
        for t_w in [1,0]:
            x = torch.randn(D,D).to('cuda').to(torch.bfloat16)
            w_bf16 = torch.randn(D,D, dtype=torch.bfloat16).cuda()
            bias = torch.randn(D, dtype=torch.bfloat16).cuda()
            if t_x:
                x = x.t()
            if t_w:
                w_bf16 = w_bf16.t()
            torch.nn.functional.linear(x, w_bf16, bias)
            torch.cuda.synchronize()
            with nvtx.annotate("cublas linear", color="red"):
                # result[D]["cublas linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: torch.nn.functional.linear(x, w_bf16, bias), quantiles=quantiles)[0]
                start = time.time()
                for _ in range(10):
                    torch.nn.functional.linear(x, w_bf16, bias)
                torch.cuda.synchronize()
                result[D]["cublas linear"][(t_x, t_w)] = time.time()-start
            del w_bf16
            w_fp16 = torch.randn(D,D, dtype=torch.float16).cuda()
            scale = torch.randn(D, dtype=torch.bfloat16).cuda()
            matmul(x, w_fp16, bias, scale)
            torch.cuda.synchronize()
            with nvtx.annotate("fp16 matmul", color="green"):
                # result[D]["fp16 matmul"][(t_x, t_w)] = triton.testing.do_bench(lambda: matmul(x, w_fp16, bias, scale), quantiles=quantiles)[0]
                start = time.time()
                for _ in range(10):
                    matmul(x, w_fp16, bias, scale)
                torch.cuda.synchronize()
                result[D]["fp16 matmul"][(t_x, t_w)] = time.time()-start
            del w_fp16
            w_int8 = torch.randint(-128, 127, (D, D), dtype=torch.int8).cuda()
            if t_w:
                w_int8 = w_int8.t()
            triton_matmul(x, w_int8)
            torch.cuda.synchronize()
            # result[D]["bfloat16 linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: matmul(x, w_bf16, bias, scale), quantiles=quantiles)[0]
            with nvtx.annotate("stock triton int8 matmul", color="blue"):
                # result[D]["stock triton int8 matmul"][(t_x, t_w)] = triton.testing.do_bench(lambda: triton_matmul(x, w_int8), quantiles=quantiles)[0]
                start = time.time()
                for _ in range(10):
                    triton_matmul(x, w_int8)
                torch.cuda.synchronize()
                result[D]["stock triton int8 matmul"][(t_x, t_w)] = time.time()-start
            int8_weight_only_linear(x, w_int8, bias, scale)
            torch.cuda.synchronize()
            with nvtx.annotate("int8 linear", color="purple"):
                # result[D]["int8 linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: int8_weight_only_linear(x, w_int8, bias, scale), quantiles=quantiles)[0]
                start = time.time()
                for _ in range(10):
                    int8_weight_only_linear(x, w_int8, bias, scale)
                torch.cuda.synchronize()
                result[D]["int8 linear"][(t_x, t_w)] = time.time()-start
            del w_int8
            w_uint4x2 = torch.randint(-128, 127, (D//2, D), dtype=torch.int8).cuda()
            if t_w:
                w_uint4x2 = torch.randint(-128, 127, (D, D//2), dtype=torch.int8).cuda().t()
            uint4x2_weight_only_linear(x, w_uint4x2, bias, scale)
            torch.cuda.synchronize()
            with nvtx.annotate("uint4x2 linear", color="black"):
                # result[D]["uint4x2 linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: uint4x2_weight_only_linear(x, w_uint4x2, bias, scale), quantiles=quantiles)[0]
                start = time.time()
                for _ in range(10):
                    uint4x2_weight_only_linear(x, w_uint4x2, bias, scale)
                torch.cuda.synchronize()
                result[D]["uint4x2 linear"][(t_x, t_w)] = time.time()-start
            del w_uint4x2
caches = {"fp16 matmul": matmul_kernel.cache, "stock triton int8 matmul": _kernel.cache, "int8 linear": int8_weight_only_linear_kernel.cache, "uint4x2 linear": uint4x2_weight_only_linear_kernel.cache}
print("| X . W | X . W.t() | X.t() . W | X.t() . W.t() | model | (M, N, K) | config |")
for d in result.keys():
    for name in result[d].keys():
        r = result[d][name]
        print(f"| {r[(0,0)]:.4f} | {r[(0,1)]:.4f} | {r[(1,0)]:.4f} | {r[(1,1)]:.4f} | {name} | {(d, d, d)} | {None if name == 'cublas linear' else caches[name][(d, d, d)]} |")
    print(" ")
# using triton version triton-nightly 2.1.0.dev20230726014945
# install: pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
# using torch compiled from 2.1.0a0+git9c2122d
# using cuda 12.1, cudnn 8.9.2 on A100 GPU
# ---------- OUTPUT --------------
# | X . W | X . W.t() | X.t() . W | X.t() . W.t() | model | (M, N, K) | config |
# | 0.0236 | 0.0246 | 0.0246 | 0.0236 | cublas linear | (1024, 1024, 1024) | None |
# | 0.0358 | 0.0358 | 0.0369 | 0.0379 | fp16 matmul | (1024, 1024, 1024) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 0.0563 | 0.0399 | 0.0563 | 0.0573 | stock triton int8 matmul | (1024, 1024, 1024) | BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
# | 0.0389 | 0.0348 | 0.0410 | 0.0379 | int8 linear | (1024, 1024, 1024) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 3 |
# | 0.0440 | 0.0399 | 0.0440 | 0.0399 | uint4x2 linear | (1024, 1024, 1024) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 0.5970 | 0.5908 | 0.5939 | 0.5888 | cublas linear | (4096, 4096, 4096) | None |
# | 0.9984 | 1.0056 | 0.9820 | 0.9800 | fp16 matmul | (4096, 4096, 4096) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 1.0803 | 0.9421 | 1.0547 | 1.0230 | stock triton int8 matmul | (4096, 4096, 4096) | BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
# | 0.9841 | 0.9472 | 1.0476 | 0.9912 | int8 linear | (4096, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 0.9984 | 1.0629 | 1.0844 | 0.9759 | uint4x2 linear | (4096, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 36.6080 | 36.0653 | 36.1764 | 35.7699 | cublas linear | (16384, 16384, 16384) | None |
# | 62.4271 | 63.5976 | 60.7140 | 61.0499 | fp16 matmul | (16384, 16384, 16384) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 64.6697 | 59.0377 | 63.2453 | 63.2904 | stock triton int8 matmul | (16384, 16384, 16384) | BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
# | 58.8196 | 59.9378 | 73.6072 | 60.7099 | int8 linear | (16384, 16384, 16384) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 59.3039 | 69.9720 | 63.4337 | 60.3843 | uint4x2 linear | (16384, 16384, 16384) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |