Skip to content

Instantly share code, notes, and snippets.

@scottt
Created May 11, 2025 21:24
Show Gist options
  • Save scottt/5a1a3c270c4cef4b65964a9206cface7 to your computer and use it in GitHub Desktop.
aotriton attn_fwd C++ test program
#include <iostream>
#include <vector>
#include <aotriton/config.h>
#include <aotriton/dtypes.h>
#include <aotriton/util.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
// Helper function to create dummy data for TensorViews, using uint16_t to fill kFloat16 values
// Allocate a flat buffer for a 4-D tensor of the given shape, with every
// element set to `value`. uint16_t is used as the raw storage for kFloat16
// bit patterns.
std::vector<uint16_t>
create_dummy_data4(const std::array<uint64_t, 4>& shape, uint16_t value) {
  // Element count is the product of all four extents.
  size_t count = shape[0];
  for (size_t i = 1; i < shape.size(); ++i) {
    count *= shape[i];
  }
  return std::vector<uint16_t>(count, value);
}
// 2-D counterpart of create_dummy_data4: flat buffer of shape[0]*shape[1]
// elements, each initialized to `value`.
std::vector<uint16_t>
create_dummy_data2(const std::array<uint64_t, 2>& shape, uint16_t value) {
  const size_t count = static_cast<size_t>(shape[0]) * shape[1];
  return std::vector<uint16_t>(count, value);
}
// Generic variant: returns a value-initialized buffer whose length is the
// product of all extents in `shape` (1 for an empty shape).
template <typename T>
std::vector<T> create_dummy_data(const std::vector<uint64_t>& shape) {
  size_t count = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    count *= shape[i];
  }
  return std::vector<T>(count, T{});
}
// Compute row-major (C-contiguous) strides, in elements, for a 4-D shape.
// For shape {d0, d1, d2, d3} the strides are {d1*d2*d3, d2*d3, d3, 1}.
//
// Bug fix: the original multiplied the *leading* extents
// (shape[2]*shape[1]*shape[0], shape[1]*shape[0], shape[0], 1), which only
// happens to be correct when the middle/trailing dims coincide (as with the
// square shapes used in main). Row-major strides are products of the
// trailing extents.
std::array<uint64_t, 4>
strides_from_shape4(const std::array<uint64_t, 4>& shape) {
  std::array<uint64_t, 4> strides = {
    shape[1] * shape[2] * shape[3],
    shape[2] * shape[3],
    shape[3],
    1,
  };
  return strides;
}
// Compute row-major (C-contiguous) strides, in elements, for a 2-D shape.
// For shape {d0, d1} the strides are {d1, 1}.
//
// Bug fix: the original returned {shape[0], 1}; the row stride must be the
// number of columns (shape[1]), not the number of rows. The two coincide
// only for square shapes.
std::array<uint64_t, 2>
strides_from_shape2(const std::array<uint64_t, 2>& shape) {
  std::array<uint64_t, 2> strides = {
    shape[1],
    1,
  };
  return strides;
}
int main() {
// Dimensions for the tensors
uint64_t batch_size = 1;
uint64_t num_heads = 32;
uint64_t seqlen_q = 64;
uint64_t seqlen_k = 64;
uint64_t head_size = 16;
AOTRITON_NS::DType tensor_dtype = AOTRITON_NS::kFloat16;
std::array<uint64_t, 4> q_shape = {batch_size, num_heads, seqlen_q, head_size};
auto q_data = create_dummy_data4(q_shape, (uint16_t)(0x3a51)); // random value
AOTRITON_NS::TensorView<4> q(reinterpret_cast<intptr_t>(q_data.data()), q_shape, strides_from_shape4(q_shape), tensor_dtype);
std::array<uint64_t, 4> k_shape = {batch_size, num_heads, seqlen_k, head_size};
auto k_data = create_dummy_data4(k_shape, 0x3bb4);
AOTRITON_NS::TensorView<4> k(reinterpret_cast<intptr_t>(k_data.data()), k_shape, strides_from_shape4(k_shape), tensor_dtype);
std::array<uint64_t, 4> v_shape = {batch_size, num_heads, seqlen_k, head_size};
auto v_data = create_dummy_data4(v_shape, 0x3b01);
AOTRITON_NS::TensorView<4> v(reinterpret_cast<intptr_t>(v_data.data()), v_shape, strides_from_shape4(v_shape), tensor_dtype);
std::array<uint64_t, 4> b_shape = {0, 0, 0, 0}; // Bias tensor
auto b_data = create_dummy_data4(b_shape, 0);
AOTRITON_NS::TensorView<4> b(reinterpret_cast<intptr_t>(b_data.data()), b_shape, strides_from_shape4(b_shape), tensor_dtype);
#if 0
std::array<uint64_t, 2> softmax_lse_shape = {num_heads, seqlen_q};
auto softmax_lse_data = create_dummy_data2(softmax_lse_shape, 0x0101);
AOTRITON_NS::TensorView<2> softmax_lse(reinterpret_cast<intptr_t>(softmax_lse_data.data()), softmax_lse_shape,
strides_from_shape2(softmax_lse_shape), aotriton::kFloat32);
#endif
auto softmax_lse_data = create_dummy_data<float>({batch_size, num_heads});
std::array<uint64_t, 2> softmax_lse_sizes = {(uint64_t)batch_size, (uint64_t)num_heads};
std::array<uint64_t, 2> softmax_lse_strides = {(uint64_t)num_heads, 1};
AOTRITON_NS::TensorView<2> softmax_lse(reinterpret_cast<intptr_t>(softmax_lse_data.data()), softmax_lse_sizes, softmax_lse_strides, tensor_dtype);
std::array<uint64_t, 4> out_shape = {batch_size, num_heads, seqlen_q, head_size};
auto out_data = create_dummy_data4(out_shape, 0);
AOTRITON_NS::TensorView<4> out(reinterpret_cast<intptr_t>(out_data.data()), out_shape,
strides_from_shape4(out_shape), tensor_dtype);
std::array<uint64_t, 4> encoded_softmax_shape = {batch_size, num_heads, seqlen_q, head_size};
auto encoded_softmax_data = create_dummy_data4({0, 0, 0, 0}, 0);
AOTRITON_NS::TensorView<4> encoded_softmax(reinterpret_cast<intptr_t>(encoded_softmax_data.data()),
encoded_softmax_shape, strides_from_shape4(encoded_softmax_shape), tensor_dtype);
// Scalar parameters
float sm_scale = 0.25f;
float dropout_p = 0.0f;
bool is_causal = true;
// Philox RNG parameters (TensorView<0> are scalar-like tensors)
// For simplicity, we'll use null data pointers for T0, as their actual usage
// might involve specific GPU scalar handling or be passed by value if the API allows.
// The AOTriton API expects these to be TensorView<0>.
uint64_t seed_val = 1;
uint64_t offset_val = 72340172838076673;
AOTRITON_NS::DType scalar_dtype = AOTRITON_NS::kUInt64;
AOTRITON_NS::TensorView<0> philox_seed(reinterpret_cast<intptr_t>(&seed_val), aotriton::kInt64);
AOTRITON_NS::TensorView<0> philox_offset1(reinterpret_cast<intptr_t>(&offset_val), aotriton::kInt64);
int64_t philox_offset2 = 0;
AOTRITON_NS::TensorView<0> philox_seed_output = AOTRITON_NS::TensorView<0>::get_null_tensor(aotriton::kUInt64);
AOTRITON_NS::TensorView<0> philox_offset_output = AOTRITON_NS::TensorView<0>::get_null_tensor(aotriton::kUInt64);
int32_t atoimc_for_causal_val = 0;
AOTRITON_NS::TensorView<0> atomic_for_causal(reinterpret_cast<intptr_t>(&atomic_for_causal), aotriton::kInt32);
AOTRITON_NS::Stream stream = nullptr; // Default stream
// Extra arguments (optional)
AOTRITON_NS::v2::flash::FwdExtraArguments *extargs = nullptr;
std::cout << "Calling AOTRITON_NS::v2::flash::attn_fwd..." << std::endl;
// Call the attn_fwd function
hipError_t result = AOTRITON_NS::v2::flash::attn_fwd(
q,
k,
v,
b,
sm_scale,
softmax_lse,
out,
dropout_p,
philox_seed,
philox_offset1,
philox_offset2,
philox_seed_output,
philox_offset_output,
encoded_softmax,
is_causal,
atomic_for_causal,
stream,
extargs
);
if (result == hipSuccess) {
std::cout << "AOTRITON_NS::v2::flash::attn_fwd call succeeded." << std::endl;
// In a real application, you would check the contents of 'out' and 'softmax_lse'
} else {
std::cerr << "AOTRITON_NS::v2::flash::attn_fwd call failed with error code: " << hipGetErrorString(result) << std::endl;
}
// Example of calling check_gpu
hipError_t err;
err = AOTRITON_NS::v2::flash::check_gpu(stream);
if (err == hipSuccess) {
std::cout << "check_gpu successful." << std::endl;
} else {
std::cerr << "check_gpu failed with error code: " << hipGetErrorString(err) << std::endl;
}
std::cout << "Synchronizing stream..." << std::endl;
err = hipStreamSynchronize(stream.native());
std::cout << "Stream synchronized. All work on this stream is complete." << std::endl;
if ((err = hipStreamDestroy(stream.native())) < 0) {
abort();
}
return 0;
}
# --- Configuration ---
# Assuming 'clang' is in your PATH. If not, provide the full path.
$clangExecutable = "clang"
# Example: $clangExecutable = "C:\Program Files\LLVM\bin\clang.exe" # Or Linux-style path if in WSL/MinGW
$ROCM_PREFIX = "/o/r-st-gfx1151/build/dist/rocm"
$AOTRITON_PREFIX = "/o/aotriton/build/install_dir"

# --- Source and Output Files ---
$sourceFile = "attn-fwd-test.cpp"
$outputFile = "attn-fwd-test.exe"

# --- Full argument list for clang ---
# Build it as a single array: each element becomes exactly one argv entry
# when splatted with @, which is the robust way to call external commands
# from PowerShell (no word-splitting or quoting surprises).
$clangArgs = @(
    "-I$ROCM_PREFIX/include",        # HIP headers
    "-D__HIP_PLATFORM_AMD__",        # AOTriton compile definition
    "-I$AOTRITON_PREFIX/include",    # AOTriton headers
    $sourceFile,
    "-L$AOTRITON_PREFIX/lib",        # AOTriton library path
    "-laotriton_v2.lib",
    "-L$ROCM_PREFIX/lib",            # HIP runtime library path
    "-lamdhip64",                    # Tells linker to find libamdhip64.so or amdhip64.lib etc.
    "-o",
    $outputFile
)
# Optional extra flags, e.g.:
# $clangArgs += "-Wall"
# $clangArgs += "-g"
# $clangArgs += "-std=c++17"

# --- Execute Clang ---
Write-Host "Attempting to compile with clang..."
# Show the command for debugging; splatting below is what actually passes
# the arguments correctly.
Write-Host "Command: & '$clangExecutable' $($clangArgs -join ' ')" # This join is just for display
try {
    # Call operator (&) + splatting (@) runs the external compiler.
    & $clangExecutable @clangArgs
    if ($LASTEXITCODE -eq 0) {
        Write-Host "Compilation successful: '$outputFile' created." -ForegroundColor Green
    } else {
        Write-Error "Clang compilation failed. Exit code: $LASTEXITCODE"
    }
} catch {
    Write-Error "An error occurred while trying to execute clang: $($_.Exception.Message)"
}
# Build every *.cpp in this directory into a like-named executable.
PROGRAMS := $(basename $(wildcard *.cpp))
# ROCm install prefix (HIP headers and runtime live under it).
ROCM_PREFIX := $(HOME)/therock-upstream-output/build/dist/rocm
HIP_CFLAGS := -I$(ROCM_PREFIX)/include
# __HIP_PLATFORM_AMD__ selects the AMD backend in the HIP headers;
# the -I points at the AOTriton install's public headers.
AOTRITON_CFLAGS := -D__HIP_PLATFORM_AMD__ -I$(HOME)/aotriton-output/build/install_dir/include
# Link against libaotriton_v2 and the HIP runtime (libamdhip64).
AOTRITON_LIBS := -L$(HOME)/aotriton-output/build/install_dir/lib -laotriton_v2 -L$(ROCM_PREFIX)/lib -lamdhip64
# Pattern rule: compile and link each test program in one clang++ step.
%: %.cpp
	clang++ $(HIP_CFLAGS) $(AOTRITON_CFLAGS) $< $(AOTRITON_LIBS) -o $@
.PHONY: all
all: $(PROGRAMS)
.PHONY: clean
clean:
	rm -f $(PROGRAMS)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment