Skip to content

Instantly share code, notes, and snippets.

@minjang
Created May 23, 2024 01:10
Show Gist options
  • Save minjang/9edf509eb9268920a916a2d56312a285 to your computer and use it in GitHub Desktop.
Save minjang/9edf509eb9268920a916a2d56312a285 to your computer and use it in GitHub Desktop.
A quick patch to make triton-cpu build and run on the ienkovich/change-cast-test-size branch (github.com/ienkovich/triton-cpu/tree/ienkovich/change-cast-test-size).
diff --git a/include/triton/Conversion/CMakeLists.txt b/include/triton/Conversion/CMakeLists.txt
index ae31ac93..691104f3 100644
--- a/include/triton/Conversion/CMakeLists.txt
+++ b/include/triton/Conversion/CMakeLists.txt
@@ -1,4 +1,4 @@
-add_subdirectory(TritonCPUToLLVM)
+# add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonGPUToLLVM)
-add_subdirectory(TritonToTritonCPU)
+# add_subdirectory(TritonToTritonCPU)
add_subdirectory(TritonToTritonGPU)
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index 83db4ae4..1998a41a 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -1,4 +1,4 @@
-#add_subdirectory(TritonToTritonCPU)
+# add_subdirectory(TritonToTritonCPU)
add_subdirectory(TritonToTritonGPU)
-#add_subdirectory(TritonCPUToLLVM)
+# add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonGPUToLLVM)
diff --git a/python/src/passes.cc b/python/src/passes.cc
index df7d9faa..716f2623 100644
--- a/python/src/passes.cc
+++ b/python/src/passes.cc
@@ -6,7 +6,7 @@
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
-#include "triton/Conversion/TritonToTritonCPU/Passes.h"
+// #include "triton/Conversion/TritonToTritonCPU/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py
index 6eb7aa29..7684d7f1 100644
--- a/python/triton/compiler/compiler.py
+++ b/python/triton/compiler/compiler.py
@@ -224,6 +224,7 @@ def filter_traceback(e: BaseException):
def compile(src, target=None, options=None):
+ print(f"Compiling from triton-cpu-ienkovich...")
if target is None:
target = driver.active.get_current_target()
assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
index d7baeb28..bf9c740a 100644
--- a/python/triton/runtime/build.py
+++ b/python/triton/runtime/build.py
@@ -18,18 +18,19 @@ def quiet():
sys.stdout, sys.stderr = old_stdout, old_stderr
-def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
+def _build(name, src, srcdir, library_dirs, include_dirs, libraries, cxx=False):
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
- cc = os.environ.get("CC")
+ cc = os.environ.get("CC" if not cxx else "CXX")
if cc is None:
# TODO: support more things here.
- clang = shutil.which("clang")
- gcc = shutil.which("gcc")
+ clang = shutil.which("clang" if not cxx else "clang++")
+ gcc = shutil.which("gcc" if not cxx else "g++")
cc = gcc if gcc is not None else clang
if cc is None:
- raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
+ raise RuntimeError("Failed to find C or C++ compiler. Please specify via the CC (C) or CXX (C++) environment variable.")
+
# This function was renamed and made public in Python 3.10
if hasattr(sysconfig, 'get_default_scheme'):
scheme = sysconfig.get_default_scheme()
diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index 1c1900a0..207ec22c 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -23,6 +23,39 @@ import torch
import triton
import triton.language as tl
+@triton.jit
+def add_kernel_block_ptr(in1_ptr, in2_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
+ id = tl.program_id(0)
+ offs = id * BLOCK_SIZE
+ in1_block_ptr = tl.make_block_ptr(
+ base=in1_ptr,
+ shape=(BLOCK_SIZE,),
+ strides=(1,),
+ offsets=(offs,),
+ block_shape=(BLOCK_SIZE,),
+ order=(0,),
+ )
+ in2_block_ptr = tl.make_block_ptr(
+ base=in2_ptr,
+ shape=(BLOCK_SIZE,),
+ strides=(1,),
+ offsets=(offs,),
+ block_shape=(BLOCK_SIZE,),
+ order=(0,),
+ )
+ out_block_ptr = tl.make_block_ptr(
+ base=out_ptr,
+ shape=(BLOCK_SIZE,),
+ strides=(1,),
+ offsets=(offs,),
+ block_shape=(BLOCK_SIZE,),
+ order=(0,),
+ )
+ val1 = tl.load(in1_block_ptr, boundary_check=())
+ val2 = tl.load(in2_block_ptr, boundary_check=())
+ val = val1 + val2
+ tl.store(out_block_ptr, val, boundary_check=())
+
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
@@ -42,14 +75,14 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses.
- mask = offsets < n_elements
+ # mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size.
- x = tl.load(x_ptr + offsets, mask=mask)
- y = tl.load(y_ptr + offsets, mask=mask)
+ x = tl.load(x_ptr + offsets)
+ y = tl.load(y_ptr + offsets)
output = x + y
# Write x + y back to DRAM.
- tl.store(output_ptr + offsets, output, mask=mask)
+ tl.store(output_ptr + offsets, output)
# %%
@@ -57,10 +90,10 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
# and (2) enqueue the above kernel with appropriate grid/block sizes:
-def add(x: torch.Tensor, y: torch.Tensor):
+def add_triton_cpu(x: torch.Tensor, y: torch.Tensor, use_block_ptr):
# We need to preallocate the output.
output = torch.empty_like(x)
- assert x.is_cuda and y.is_cuda and output.is_cuda
+ #assert not x.is_cuda and not y.is_cuda and not output.is_cuda
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -70,7 +103,10 @@ def add(x: torch.Tensor, y: torch.Tensor):
# - Each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
# - Don't forget to pass meta-parameters as keywords arguments.
- add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
+ if use_block_ptr:
+ add_kernel_block_ptr[grid](x, y, output, BLOCK_SIZE=128)
+ else:
+ add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return output
@@ -80,12 +116,16 @@ def add(x: torch.Tensor, y: torch.Tensor):
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
torch.manual_seed(0)
-size = 98432
-x = torch.rand(size, device='cuda')
-y = torch.rand(size, device='cuda')
+size = 128
+x = torch.rand(size, device='cpu')
+y = torch.rand(size, device='cpu')
output_torch = x + y
-output_triton = add(x, y)
print(output_torch)
+output_triton = add_triton_cpu(x, y, True)
+print(output_triton)
+print(f'The maximum difference between torch and triton_block_ptr is '
+ f'{torch.max(torch.abs(output_torch - output_triton))}')
+output_triton = add_triton_cpu(x, y, False)
print(output_triton)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}')
@@ -130,4 +170,4 @@ def benchmark(size, provider):
# %%
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
# `save_path='/path/to/results/' to save them to disk along with raw CSV data:
-benchmark.run(print_data=True, show_plots=True)
+# benchmark.run(print_data=True, show_plots=True)
diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py
index 357b5f44..13a2d006 100644
--- a/third_party/cpu/backend/compiler.py
+++ b/third_party/cpu/backend/compiler.py
@@ -137,8 +137,10 @@ class CPUBackend(BaseBackend):
@staticmethod
def make_bc(src, metadata, options):
if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1":
- print("********** Module ASM **********")
- print(llvm.translate_to_host_asm(src))
+ file_name = f"{metadata['name']}.asm"
+ print(f"********** Writing .asm to {file_name} **********")
+ with open(file_name, "w") as f:
+ f.write(llvm.translate_to_host_asm(src))
ret = llvm.translate_to_bc(src)
return ret
diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py
index 743684d2..14216725 100644
--- a/third_party/cpu/backend/driver.py
+++ b/third_party/cpu/backend/driver.py
@@ -74,6 +74,7 @@ libraries = [
"LLVMSupport",
"LLVMDemangle",
"stdc++",
+ "z",
]
@@ -86,7 +87,7 @@ def compile_module_from_src(src, name):
src_path = os.path.join(tmpdir, "main.cpp")
with open(src_path, "w") as f:
f.write(src)
- so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries)
+ so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries, True)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
import importlib.util
diff --git a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt
index 64b36523..0936dff1 100644
--- a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt
+++ b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt
@@ -1,3 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM)
-add_public_tablegen_target(TritonCPUConversionPassIncGen)
+add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment