aotriton attn_fwd C++ test program
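A minimal standalone C++ program that exercises AOTriton's v2 flash-attention forward kernel (attn_fwd) with dummy fp16 tensors, followed by a PowerShell build script and a Makefile for compiling it against a ROCm tree and an AOTriton install prefix.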
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>
#include <aotriton/config.h>
#include <aotriton/dtypes.h>
#include <aotriton/util.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
// Helper functions to create dummy data for TensorViews, using uint16_t to fill kFloat16 values
std::vector<uint16_t>
create_dummy_data4(const std::array<uint64_t, 4>& shape, uint16_t value) {
  size_t total_elements = 1;
  for (uint64_t dim : shape) {
    total_elements *= dim;
  }
  return std::vector<uint16_t>(total_elements, value);
}

std::vector<uint16_t>
create_dummy_data2(const std::array<uint64_t, 2>& shape, uint16_t value) {
  size_t total_elements = 1;
  for (uint64_t dim : shape) {
    total_elements *= dim;
  }
  return std::vector<uint16_t>(total_elements, value);
}

template <typename T>
std::vector<T> create_dummy_data(const std::vector<uint64_t>& shape) {
  size_t total_elements = 1;
  for (uint64_t dim : shape) {
    total_elements *= dim;
  }
  return std::vector<T>(total_elements, T{});
}
std::array<uint64_t, 4>
strides_from_shape4(const std::array<uint64_t, 4>& shape) {
  // Contiguous row-major strides: each stride is the product of all
  // trailing dimensions, with the innermost dimension last.
  std::array<uint64_t, 4> strides = {
    shape[1] * shape[2] * shape[3],
    shape[2] * shape[3],
    shape[3],
    1,
  };
  return strides;
}

std::array<uint64_t, 2>
strides_from_shape2(const std::array<uint64_t, 2>& shape) {
  std::array<uint64_t, 2> strides = {
    shape[1],
    1,
  };
  return strides;
}
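// Example (using the shapes from main below): for shape {1, 32, 64, 16},
// strides_from_shape4 returns {32*64*16, 64*16, 16, 1} = {32768, 1024, 16, 1}.
// These are contiguous row-major strides, measured in elements rather than bytes.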
int main() {
  // Dimensions for the tensors
  uint64_t batch_size = 1;
  uint64_t num_heads = 32;
  uint64_t seqlen_q = 64;
  uint64_t seqlen_k = 64;
  uint64_t head_size = 16;
  AOTRITON_NS::DType tensor_dtype = AOTRITON_NS::kFloat16;
  std::array<uint64_t, 4> q_shape = {batch_size, num_heads, seqlen_q, head_size};
  auto q_data = create_dummy_data4(q_shape, (uint16_t)0x3a51); // arbitrary fp16 bit pattern (~0.79)
  AOTRITON_NS::TensorView<4> q(reinterpret_cast<intptr_t>(q_data.data()), q_shape,
                               strides_from_shape4(q_shape), tensor_dtype);
  std::array<uint64_t, 4> k_shape = {batch_size, num_heads, seqlen_k, head_size};
  auto k_data = create_dummy_data4(k_shape, 0x3bb4);
  AOTRITON_NS::TensorView<4> k(reinterpret_cast<intptr_t>(k_data.data()), k_shape,
                               strides_from_shape4(k_shape), tensor_dtype);
  std::array<uint64_t, 4> v_shape = {batch_size, num_heads, seqlen_k, head_size};
  auto v_data = create_dummy_data4(v_shape, 0x3b01);
  AOTRITON_NS::TensorView<4> v(reinterpret_cast<intptr_t>(v_data.data()), v_shape,
                               strides_from_shape4(v_shape), tensor_dtype);
  std::array<uint64_t, 4> b_shape = {0, 0, 0, 0}; // empty bias tensor (no bias)
  auto b_data = create_dummy_data4(b_shape, 0);
  AOTRITON_NS::TensorView<4> b(reinterpret_cast<intptr_t>(b_data.data()), b_shape,
                               strides_from_shape4(b_shape), tensor_dtype);
#if 0
  std::array<uint64_t, 2> softmax_lse_shape = {num_heads, seqlen_q};
  auto softmax_lse_data = create_dummy_data2(softmax_lse_shape, 0x0101);
  AOTRITON_NS::TensorView<2> softmax_lse(reinterpret_cast<intptr_t>(softmax_lse_data.data()), softmax_lse_shape,
                                         strides_from_shape2(softmax_lse_shape), aotriton::kFloat32);
#endif
  auto softmax_lse_data = create_dummy_data<float>({batch_size, num_heads});
  std::array<uint64_t, 2> softmax_lse_sizes = {batch_size, num_heads};
  std::array<uint64_t, 2> softmax_lse_strides = {num_heads, 1};
  // The log-sum-exp buffer holds float values, so the view must be kFloat32.
  AOTRITON_NS::TensorView<2> softmax_lse(reinterpret_cast<intptr_t>(softmax_lse_data.data()),
                                         softmax_lse_sizes, softmax_lse_strides, aotriton::kFloat32);
  std::array<uint64_t, 4> out_shape = {batch_size, num_heads, seqlen_q, head_size};
  auto out_data = create_dummy_data4(out_shape, 0);
  AOTRITON_NS::TensorView<4> out(reinterpret_cast<intptr_t>(out_data.data()), out_shape,
                                 strides_from_shape4(out_shape), tensor_dtype);
  // Dropout is disabled below, so pass an empty encoded_softmax tensor
  // (all-zero shape, same pattern as the bias tensor above).
  std::array<uint64_t, 4> encoded_softmax_shape = {0, 0, 0, 0};
  auto encoded_softmax_data = create_dummy_data4(encoded_softmax_shape, 0);
  AOTRITON_NS::TensorView<4> encoded_softmax(reinterpret_cast<intptr_t>(encoded_softmax_data.data()),
                                             encoded_softmax_shape, strides_from_shape4(encoded_softmax_shape), tensor_dtype);
  // Scalar parameters
  float sm_scale = 0.25f; // 1/sqrt(head_size) with head_size = 16
  float dropout_p = 0.0f;
  bool is_causal = true;
  // Philox RNG parameters (TensorView<0> are scalar-like tensors).
  // The seed/offset inputs wrap host scalars; the two outputs are null
  // tensors since this test does not read the RNG state back.
  uint64_t seed_val = 1;
  uint64_t offset_val = 72340172838076673; // 0x0101010101010101
  AOTRITON_NS::DType scalar_dtype = AOTRITON_NS::kUInt64;
  AOTRITON_NS::TensorView<0> philox_seed(reinterpret_cast<intptr_t>(&seed_val), scalar_dtype);
  AOTRITON_NS::TensorView<0> philox_offset1(reinterpret_cast<intptr_t>(&offset_val), scalar_dtype);
  int64_t philox_offset2 = 0;
  AOTRITON_NS::TensorView<0> philox_seed_output = AOTRITON_NS::TensorView<0>::get_null_tensor(scalar_dtype);
  AOTRITON_NS::TensorView<0> philox_offset_output = AOTRITON_NS::TensorView<0>::get_null_tensor(scalar_dtype);
  int32_t atomic_for_causal_val = 0;
  AOTRITON_NS::TensorView<0> atomic_for_causal(reinterpret_cast<intptr_t>(&atomic_for_causal_val), aotriton::kInt32);
  AOTRITON_NS::Stream stream = nullptr; // the default HIP stream
  // Extra arguments (optional)
  AOTRITON_NS::v2::flash::FwdExtraArguments *extargs = nullptr;
std::cout << "Calling AOTRITON_NS::v2::flash::attn_fwd..." << std::endl; | |
// Call the attn_fwd function | |
hipError_t result = AOTRITON_NS::v2::flash::attn_fwd( | |
q, | |
k, | |
v, | |
b, | |
sm_scale, | |
softmax_lse, | |
out, | |
dropout_p, | |
philox_seed, | |
philox_offset1, | |
philox_offset2, | |
philox_seed_output, | |
philox_offset_output, | |
encoded_softmax, | |
is_causal, | |
atomic_for_causal, | |
stream, | |
extargs | |
); | |
  if (result == hipSuccess) {
    std::cout << "AOTRITON_NS::v2::flash::attn_fwd call succeeded." << std::endl;
    // In a real application, you would check the contents of 'out' and 'softmax_lse'.
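    // Illustrative idea for such a check (an assumption, not part of the
    // original test): every V element holds the same fp16 constant (0x3b01)
    // and the softmax weights sum to 1, so after the stream is synchronized
    // each element of out_data should be ~0x3b01 up to fp16 rounding.
    // Note the buffers here are plain host std::vectors, which assumes the
    // GPU can access host memory (e.g. unified memory / HMM).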
  } else {
    std::cerr << "AOTRITON_NS::v2::flash::attn_fwd call failed with error code: "
              << hipGetErrorString(result) << std::endl;
  }
  // Example of calling check_gpu
  hipError_t err;
  err = AOTRITON_NS::v2::flash::check_gpu(stream);
  if (err == hipSuccess) {
    std::cout << "check_gpu successful." << std::endl;
  } else {
    std::cerr << "check_gpu failed with error code: " << hipGetErrorString(err) << std::endl;
  }
std::cout << "Synchronizing stream..." << std::endl; | |
err = hipStreamSynchronize(stream.native()); | |
std::cout << "Stream synchronized. All work on this stream is complete." << std::endl; | |
if ((err = hipStreamDestroy(stream.native())) < 0) { | |
abort(); | |
} | |
return 0; | |
} |
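Two ways to build the test follow: a PowerShell script (Windows) and a Makefile (Linux). Both compile attn-fwd-test.cpp with clang against a ROCm tree and an AOTriton install prefix; adjust the prefix paths for your machine.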
# --- Configuration ---
# Assuming 'clang' is in your PATH. If not, provide the full path.
$clangExecutable = "clang"
# Example: $clangExecutable = "C:\Program Files\LLVM\bin\clang.exe" # Or Linux-style path if in WSL/MinGW
$ROCM_PREFIX = "/o/r-st-gfx1151/build/dist/rocm"
$AOTRITON_PREFIX = "/o/aotriton/build/install_dir"

# --- Compiler Flags ---
# -I<path> is a single argument for clang, so this is fine as a single string.
$HIP_CFLAGS = "-I$ROCM_PREFIX/include"
# These are two distinct arguments for clang, so store them as an array.
$AOTRITON_CFLAGS = @(
    "-D__HIP_PLATFORM_AMD__",
    "-I$AOTRITON_PREFIX/include"
)

# --- Linker Flags and Libraries ---
# These are multiple distinct arguments.
$AOTRITON_LIBS = @(
    "-L$AOTRITON_PREFIX/lib",
    "-laotriton_v2.lib",
    "-L$ROCM_PREFIX/lib",
    "-lamdhip64" # Tells the linker to find libamdhip64.so, amdhip64.lib, etc.
)

# --- Source and Output Files ---
$sourceFile = "attn-fwd-test.cpp"
$outputFile = "attn-fwd-test.exe"

# --- Construct the full argument list for clang ---
# This is the most robust way in PowerShell to build arguments for external commands.
$clangArgs = @()
$clangArgs += $HIP_CFLAGS      # Adds the string "-I$ROCM_PREFIX/include" as one argument
$clangArgs += $AOTRITON_CFLAGS # Adds the array elements as separate arguments
$clangArgs += $sourceFile
$clangArgs += $AOTRITON_LIBS   # Adds the array elements as separate arguments
$clangArgs += "-o"
$clangArgs += $outputFile
# Optional: add other flags like -Wall, -g, -std=c++17, etc.
# $clangArgs += "-Wall"
# $clangArgs += "-g"
# $clangArgs += "-std=c++17"

# --- Execute Clang ---
Write-Host "Attempting to compile with clang..."
# Display the command that will be executed (for debugging/verification).
# PowerShell's default stringification of an array for Write-Host might not look exactly
# like the shell command, but the splatting operator (@) ensures arguments are passed correctly.
Write-Host "Command: & '$clangExecutable' $($clangArgs -join ' ')" # This join is just for display
try {
    # Use the call operator (&) and splatting (@)
    & $clangExecutable @clangArgs
    if ($LASTEXITCODE -eq 0) {
        Write-Host "Compilation successful: '$outputFile' created." -ForegroundColor Green
    } else {
        Write-Error "Clang compilation failed. Exit code: $LASTEXITCODE"
    }
} catch {
    Write-Error "An error occurred while trying to execute clang: $($_.Exception.Message)"
}
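To use the script, save it alongside attn-fwd-test.cpp under any name (build.ps1 is a hypothetical example) and run it from a PowerShell prompt after updating $ROCM_PREFIX and $AOTRITON_PREFIX for your install locations.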
PROGRAMS := $(basename $(wildcard *.cpp))
ROCM_PREFIX := $(HOME)/therock-upstream-output/build/dist/rocm
HIP_CFLAGS := -I$(ROCM_PREFIX)/include
AOTRITON_CFLAGS := -D__HIP_PLATFORM_AMD__ -I$(HOME)/aotriton-output/build/install_dir/include
AOTRITON_LIBS := -L$(HOME)/aotriton-output/build/install_dir/lib -laotriton_v2 -L$(ROCM_PREFIX)/lib -lamdhip64

%: %.cpp
	clang++ $(HIP_CFLAGS) $(AOTRITON_CFLAGS) $< $(AOTRITON_LIBS) -o $@

.PHONY: all
all: $(PROGRAMS)

.PHONY: clean
clean:
	rm -f $(PROGRAMS)
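# Usage (assumes clang++ is on PATH and the prefixes above exist):
#   make all    # build every *.cpp in this directory into a same-named binary
#   make clean  # remove the built binaries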