Skip to content

Instantly share code, notes, and snippets.

View ita9naiwa's full-sized avatar

Hyunsung Lee ita9naiwa

View GitHub Profile
import torch
import triton
import triton.language as tl
@triton.jit
def mxfp_matmul(
a_ptr, b_ptr, output_ptr,
a_scale, b_scale,
M, N, K,
stride_scale: tl.constexpr,
#!/usr/bin/env python3
import torch
import triton
import triton.language as tl
@triton.jit
def scaled_dot_kernel(
# Pointers to matrices
a_ptr, b_ptr, output_ptr,