import math
import torch
import triton
import triton.language as tl
from typing import Optional, Tuple


def _has_triton():
    if not torch.cuda.is_available():
        return False
    try:
        import triton
        return triton is not None and torch.cuda.get_device_capability() >= (7, 0)
    except ImportError:
        return False


def check(cond, msg):
    if not cond:
        raise ValueError(msg)


def check_bsr_layout(f_name, t):
    check(
        t.layout == torch.sparse_bsr,
        f"{f_name}(): only BSR sparse format is supported for the sparse argument.",
    )


def check_device(f_name, t, device):
    check(
        t.device == device and t.device.type == "cuda",
        f"{f_name}(): all inputs are expected to be on the same GPU device.",
    )


def check_mm_compatible_shapes(f_name, lhs, rhs):
    check(
        lhs.dim() >= 2 and rhs.dim() >= 2,
        f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, "
        f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}."
    )

    m, kl = lhs.shape[-2:]
    kr, n = rhs.shape[-2:]

    check(
        kl == kr,
        f"{f_name}(): arguments' sizes involved in the matrix product are not compatible for matrix multiplication, "
        f"got lhs.shape[-1] == {kl} which is not equal to rhs.shape[-2] == {kr}.",
    )


def check_dtype(f_name, t, dtype, *additional_dtypes):
    check(
        t.dtype == dtype
        and t.dtype in ((torch.half, torch.bfloat16, torch.float) + tuple(additional_dtypes)),
        f"{f_name}(): all inputs are expected to be of the same dtype "
        f"and one of (half, bfloat16, float32) or {additional_dtypes}, "
        f"but got dtype == {t.dtype}.",
    )


def check_blocksize(f_name, blocksize):
    assert len(blocksize) == 2

    def is_power_of_two(v):
        return not (v & (v - 1))
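        # e.g. 16 & 15 == 0b10000 & 0b01111 == 0, while 24 & 23 == 0b11000 & 0b10111 == 0b10000 != 0.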

    def is_compatible_blocksize(b):
        res = True
        for blocksize in b:
            # Triton loads only blocks which are at least 16 and powers of 2.
            res = (blocksize >= 16 and is_power_of_two(blocksize)) and res
        return res

    check(
        is_compatible_blocksize(blocksize),
        f"{f_name}(): sparse inputs' blocksize ({blocksize[0]}, {blocksize[1]}) "
        "should be at least 16 and a power of 2 in each dimension.",
    )


def make_triton_contiguous(t):
    if t.stride(-2) > 1 and t.stride(-1) > 1:
        return t.contiguous()
    else:
        return t


def broadcast_batch_dims(f_name, *tensors):
    try:
        return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors))
    except Exception:
        check(False, f"{f_name}(): inputs' batch dimensions are not broadcastable!")


def slicer(dim, slice_range, *tensors):
    for t in tensors:
        slices = [slice(None)] * t.dim()
        slices[dim] = slice_range
        yield t[slices]


def multidim_slicer(dims, slices, *tensors):
    for t in tensors:
        s = [slice(None)] * t.dim()
        for d, d_slice in zip(dims, slices):
            if d is not None:
                s[d] = d_slice
        yield t[s]


def ptr_stride_extractor(*tensors):
    for t in tensors:
        yield t
        yield from t.stride()
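# e.g. for a 2-D tensor a and a 3-D tensor b, ptr_stride_extractor(a, b) yields:
#   a, a.stride(0), a.stride(1), b, b.stride(0), b.stride(1), b.stride(2)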


def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
    assert 0 <= len(full_grid) <= 3
    assert 0 <= len(grid_blocks) <= 3

    import itertools

    def generate_grid_points():
        for fg, mg in zip(full_grid, grid_blocks):
            yield range(0, fg, mg)

    def generate_sliced_tensors(slices):
        for t, t_dims in tensor_dims_map.items():
            yield next(multidim_slicer(t_dims, slices, t))

    for grid_point in itertools.product(*generate_grid_points()):
        grid = [min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks)]
        slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)]
        # grid_points are iterated in a "contiguous" order, i.e.
        # left dimensions traversed slower than right dimensions.
        # This order is reversed for CUDA grids.
        yield grid[::-1], *generate_sliced_tensors(slices)
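# Example (hypothetical sizes): full_grid == (2, 5) with grid_blocks == (2, 3)
# yields CUDA grids (3, 2) and (2, 2) (note the reversal), paired with tensors
# sliced to [0:2] x [0:3] and [0:2] x [3:5] along the dims in tensor_dims_map.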


def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
    # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
    cuda_max_grid = (2147483647, 65535, 65535)[::-1]
    if grid_blocks is None:
        grid_blocks = cuda_max_grid
    else:

        def valid_grid_dim(g, mg):
            if g is None:
                return mg
            else:
                # grid must be at least 1 and no greater than mg
                return max(1, min(g, mg))

        grid_blocks = tuple(
            valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid)
        )  # type: ignore[assignment]

    for grid, *sliced_tensors in grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
        kernel(grid, *sliced_tensors)


def prepare_inputs(bsr, *dense_tensors):
    # Introduce fake batch dimension if not present for convenience.
    crow_indices = bsr.crow_indices().unsqueeze(0)
    col_indices = bsr.col_indices().unsqueeze(0)
    values = make_triton_contiguous(bsr.values().unsqueeze(0))
    tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors]

    # Compute broadcasted batch dimension
    batch_dims_broadcasted = torch.broadcast_shapes(values.shape[:-3], *(t.shape[:-2] for t in tensors))

    # Broadcast batch dimensions and squash
    def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
        return t.broadcast_to(batch_dims + invariant_dims).flatten(
            0, len(batch_dims) - 1
        )

    crow_indices = batch_broadcast_and_squash(
        crow_indices, batch_dims_broadcasted, (-1,)
    )
    col_indices = batch_broadcast_and_squash(
        col_indices, batch_dims_broadcasted, (-1,)
    )
    values = batch_broadcast_and_squash(
        values, batch_dims_broadcasted, values.shape[-3:]
    )
    tensors = [
        batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:]) for t in tensors
    ]

    return crow_indices, col_indices, values, *tensors
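# prepare_inputs thus returns, with B the flattened broadcast batch size:
#   crow_indices: (B, n_block_rows + 1), col_indices: (B, nnz),
#   values: (B, nnz, blocksize[0], blocksize[1]), and each dense tensor: (B, rows, cols).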


def broadcast_batch_dims_bsr(f_name, bsr, *tensors):
    batch_shape = broadcast_batch_dims(f_name, bsr, *tensors)

    crow_indices = bsr.crow_indices().broadcast_to(batch_shape + (-1,))
    col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,))
    values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:])
    size = batch_shape + bsr.shape[-2:]
    return torch.sparse_compressed_tensor(crow_indices, col_indices, values, size=size, layout=bsr.layout)


# NOTE: this function returns a view as long as `t` is contiguous in its last
# two dimensions (reshape then never copies).
def tile_to_blocksize(t, blocksize):
    *rest, m, n = t.shape
    new_shape = rest + [
        m // blocksize[0],
        blocksize[0],
        n // blocksize[1],
        blocksize[1],
    ]
    return t.reshape(new_shape).transpose(-3, -2)
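# Example: t of shape (B, 4, 9) with blocksize (2, 3) is reshaped to
# (B, 2, 2, 3, 3) and transposed to (B, 2, 3, 2, 3), i.e. a
# (batch, m // bs0, n // bs1, bs0, bs1) grid of blocks addressable by
# (tiled_row, tiled_col).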


@triton.jit
def _sampled_addmm_kernel(
    alpha,
    beta,
    IS_BETA_ZERO: tl.constexpr,
    BLOCKSIZE_ROW: tl.constexpr,
    BLOCKSIZE_COL: tl.constexpr,
    k,
    TILE_K: tl.constexpr,
    values_ptr,
    values_batch_stride,
    values_nnz_stride,
    values_row_block_stride,
    values_col_block_stride,
    crow_indices_ptr,
    crow_indices_batch_stride,
    crow_indices_stride,
    col_indices_ptr,
    col_indices_batch_stride,
    col_indices_stride,
    mat1_ptr,
    mat1_batch_stride,
    mat1_tiled_row_stride,
    mat1_tiled_col_stride,
    mat1_row_block_stride,
    mat1_col_block_stride,
    mat2_ptr,
    mat2_batch_stride,
    mat2_tiled_row_stride,
    mat2_tiled_col_stride,
    mat2_row_block_stride,
    mat2_col_block_stride,
    acc_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
):
    batch_pid = tl.program_id(axis=1)
    row_block_pid = tl.program_id(axis=0)

    crow_indices_offset_ptr = (
        crow_indices_ptr
        + crow_indices_batch_stride * batch_pid
        + crow_indices_stride * row_block_pid
    )
    nnz_offset = tl.load(crow_indices_offset_ptr)
    nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

    # Compute nnz for the row with number row_block_pid.
    # If it is zero, skip the row.
    row_nnz = nnz_offset_next - nnz_offset
    if row_nnz == 0:
        return

    row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
    col_block_arange = tl.arange(0, BLOCKSIZE_COL)

    # Pointers are set to the first block of the current row.
    values_block_ptrs = (
        values_ptr
        + values_batch_stride * batch_pid
        + values_nnz_stride * nnz_offset
        + values_row_block_stride * row_block_arange[:, None]
        + values_col_block_stride * col_block_arange[None, :]
    )

    col_index_nnz_ptr = (
        col_indices_ptr
        + col_indices_batch_stride * batch_pid
        + col_indices_stride * nnz_offset
    )

    # Advance mat1 to the current tiled row, ignore columns.
    mat1_block_ptrs = (
        mat1_ptr
        + mat1_batch_stride * batch_pid
        + mat1_tiled_row_stride * row_block_pid
        + mat1_row_block_stride * row_block_arange[:, None]
    )

    # Advance mat2 in batch and block col dimension.
    mat2_block_ptrs = (
        mat2_ptr
        + mat2_batch_stride * batch_pid
        + mat2_col_block_stride * col_block_arange[None, :]
    )

    k_tile_arange = tl.arange(0, TILE_K)
    for _ in range(row_nnz):
        acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)

        # find column block index
        col_block = tl.load(col_index_nnz_ptr)

        for k_tile in range(0, k, TILE_K):
            k_offsets = k_tile + k_tile_arange
            mask_k = k_offsets < k

            mat1_block = tl.load(
                mat1_block_ptrs
                + mat1_col_block_stride * k_offsets[None, :],
                mask=mask_k[None, :], other=0.0
            )

            mat2_block = tl.load(
                mat2_block_ptrs
                + mat2_tiled_col_stride * col_block
                + mat2_row_block_stride * k_offsets[:, None],
                mask=mask_k[:, None], other=0.0
            )

            acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32)

        if IS_BETA_ZERO:
            acc_block *= alpha
        else:
            acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)

        # write result
        tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))

        # advance val/col_index ptrs to the next block in the row.
        values_block_ptrs += values_nnz_stride
        col_index_nnz_ptr += col_indices_stride


def _run_sampled_addmm_kernel(
    alpha, beta, is_beta_zero,
    blocksize, k, tile_k,
    values, crow_indices, col_indices,
    mat1, mat2,
    max_grid
):
    n_batches = values.size(0)
    n_block_rows = crow_indices.size(-1) - 1

    full_grid = (n_batches, n_block_rows)
    if max_grid is not None:
        grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))
    else:
        grid_blocks = None
    tensor_dims_map = {
        values: (0, None),
        crow_indices: (0, -1),
        col_indices: (0, None),
        mat1: (0, -4),
        mat2: (0, None),
    }
    if values.dtype in (torch.half, torch.bfloat16):
        acc_dtype = tl.float32
        allow_tf32 = True
    else:
        acc_dtype = tl.float64
        allow_tf32 = False
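    # Half/bfloat16 inputs accumulate in float32 with tf32 tensor cores allowed;
    # all other dtypes accumulate in float64 with tf32 disabled.
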
    def kernel(grid, *sliced_tensors):
        # breakpoint()  # going in, everything looks good
        _sampled_addmm_kernel[grid](
            alpha, beta, is_beta_zero,
            *blocksize, k, tile_k,
            *ptr_stride_extractor(*sliced_tensors),
            acc_dtype=acc_dtype,
            allow_tf32=allow_tf32,
            num_stages=1,
            num_warps=4
        )
        # breakpoint()  # on the way out, sliced_tensors[0] (the sparse values tensor) has an alignment error

    launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)


def sampled_addmm(
    input: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    *,
    beta=1.0,
    alpha=1.0,
    out: Optional[torch.Tensor] = None,
    skip_checks: bool = False,
    max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
):
    f_name = "sampled_addmm"

    check_bsr_layout(f_name, input)
    input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)

    if not skip_checks:
        check_device(f_name, mat1, input.device)
        check_device(f_name, mat2, input.device)
        if beta != 0.0 and input.dtype is torch.bool:
            check(
                False,
                f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed."
            )
        if input.dtype is not torch.bool:
            check_dtype(f_name, mat1, input.dtype)
            check_dtype(f_name, mat2, input.dtype)
        else:
            check_dtype(f_name, mat1, mat2.dtype)
        check_mm_compatible_shapes(f_name, mat1, mat2)
        if out is not None:
            check_bsr_layout(f_name, out)
            check_device(f_name, out, mat1.device)
            check_dtype(f_name, out, input.dtype)
            check(
                out.shape == input_broadcasted.shape
                and out._nnz() == input._nnz(),
                f"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} "
                f"and with nnz equal to {input_broadcasted._nnz()} "
                f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}"
            )

    if out is None:
        out = input_broadcasted.to(mat1.dtype, copy=True)
    else:
        out.copy_(input_broadcasted)

    if out.numel() == 0 or out._nnz() == 0:
        return out

    blocksize = out.values().shape[-2:]
    m = mat1.size(-2)
    n = mat2.size(-1)
    k = mat1.size(-1)

    # NOTE: (m, 0) @ (0, n) == zeros(m, n)
    if alpha == 0.0 or k == 0:
        out.values().mul_(beta)
        return out

    # prepare inputs by reshaping them to be kernel-compatible
    out_backup = out
    crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)

    mat1 = tile_to_blocksize(mat1, (blocksize[0], k))
    mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))
    tile_k = max(*blocksize)

    _run_sampled_addmm_kernel(
        alpha, beta, beta == 0.0,
        blocksize, k, tile_k,
        values, crow_indices, col_indices,
        mat1, mat2,
        max_grid
    )

    # If the nnz x block strides are not the same in out_backup.values and values,
    # then out_backup.values and values are not views of each other,
    # so we have to copy.
    if out_backup.values().stride()[-3:] != values.stride()[-3:]:
        out_backup.values().copy_(values.reshape(out_backup.values().shape))
    return out_backup


################ end triton code from pytorch ##########################
################ reproducer below ######################################

from functools import partial
from torch.testing import make_tensor

DEVICE = 'cuda'


def make_inputs(dtype, blocksize, m, n, k):
    new_tensor = partial(make_tensor, device=DEVICE, dtype=dtype, low=0.3, high=1.2)
    mask = new_tensor(m, n).tril_()
    mask_bsr = mask.to_sparse_bsr(blocksize)
    m1 = new_tensor((m, k))
    m2 = new_tensor((n, k))
    m3 = new_tensor((n, k))
    return mask_bsr, m1, m2, m3


if __name__ == "__main__":
    mask_bsr, m1, m2, m3 = make_inputs(torch.bfloat16, 16, 64, 64, 64)
    result = sampled_addmm(mask_bsr, m1, m2.transpose(-2, -1))
    print(result)
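
    # Hypothetical sanity check (not part of the original gist): rebuild the
    # stored-block pattern densely and compare against a plain dense reference
    # of beta * input + alpha * (m1 @ m2^T) restricted to that pattern. Loose
    # tolerances account for bfloat16 inputs with float32 accumulation.
    block_pattern = torch.sparse_bsr_tensor(
        mask_bsr.crow_indices(), mask_bsr.col_indices(),
        torch.ones_like(mask_bsr.values()), size=mask_bsr.shape,
    ).to_dense()
    ref = (mask_bsr.to_dense() + m1 @ m2.transpose(-2, -1)) * block_pattern
    print(torch.allclose(result.to_dense().float(), ref.float(), rtol=1e-2, atol=1e-1))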
See here for old/new pin hashes.
Old pin executes cleanly, new pin produces...