Skip to content

Instantly share code, notes, and snippets.

@vient
Last active August 13, 2024 00:39
Show Gist options
  • Save vient/15ae1d54fdec78b7785bdccdb39a4458 to your computer and use it in GitHub Desktop.
Save vient/15ae1d54fdec78b7785bdccdb39a4458 to your computer and use it in GitHub Desktop.
AMX test
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <immintrin.h>
#include <unistd.h>
#include <sys/syscall.h>
// Option code handed to the arch_prctl system call at the top of main() to
// request permission for an extended-state feature. Defined locally as a macro;
// presumably mirrors the value in the Linux UAPI headers - confirm against
// <asm/prctl.h> if that header is ever included here.
#define ARCH_REQ_XCOMP_PERM 0x1023
// Extended-state feature numbers; going by the names, the AMX tile
// configuration and tile data components. main() requests XFEATURE_XTILEDATA.
#define XFEATURE_XTILECFG 17
#define XFEATURE_XTILEDATA 18
// Bitmask with both of the feature bits above set. Not referenced by the code
// visible in this file.
#define XFEATURE_MASK_XTILE ((1 << XFEATURE_XTILECFG) | (1 << XFEATURE_XTILEDATA))
// Sample "image": 64 Mi (2^26) 32-bit pixels, 256 MiB in total, kept at
// namespace scope. main() fills it with one repeated pixel value and then reads
// it back 16 pixels at a time.
// Note: adding up one 8-bit channel over every element can reach 2^26 * 255,
// which does not fit in a 32-bit integer.
std::array<std::uint32_t, 64 * 1024 * 1024> PixelData;
int main()
{
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA))
{
printf("\n Failed to enable XFEATURE_XTILEDATA \n\n");
return 1;
}
// {number of rows, column size in bytes}
__tile1024i PixelTile = {16, 4}; // 16rowsx4b (16x1) 16 pixels (16 ints)
__tile1024i MaskTile = {4, 64}; // 4rowx16b (4x16) masks (4 x 16 ints)
__tile1024i SumTile = {4, 4}; // 4rowsx4b (4x1) four RGBA sums (4 ints)
// [R Sum32] [RRRRRRRR...] [ RGBA ]
// [G Sum32] += [GGGGGGGG...] * [ RGBA ]
// [B Sum32] [BBBBBBBB...] [ RGBA ]
// [A Sum32] [AAAAAAAA...] [ RGBA ]
// Sums Masks [ RGBA ]
// [ ... ]
// Pixels
// Generate Mask-Matrix
std::array<std::uint32_t, 4 * 16> MaskData;
for( std::size_t ChannelIndex = 0; ChannelIndex < 4; ++ChannelIndex )
{
for( std::size_t j = 0; j < 16; ++j )
{
// Each row is masking a particular RGBA channel.
// 0: 0x00'00'00'01
// 1: 0x00'00'01'00
// 2: 0x00'01'00'00
// 3: 0x01'00'00'00
MaskData[j + ChannelIndex * 16]
= (uint32_t(1) << (ChannelIndex * 8));
}
}
// Load mask-matrix
// Each row is composed of 16x32-bit integers. 64 bytes per row
__tile_loadd(&MaskTile, MaskData.data(), sizeof(std::uint32_t) * 16);
// Initialize the Sum to 0, 0, 0, 0
__tile_zero(&SumTile);
// Generate a sample RGBA "image" of all the same pixels
PixelData.fill(0xAA'BB'CC'DD); // Fill it with some pixel data
// Process 16 pixels at a time
for( std::size_t i = 0; i < PixelData.size(); i += 16 )
{
// Load 64 bytes of RGBA pixel data, 16 pixels
// Be careful here, each "row" is 4 bytes long, so the stride is 4 bytes
__tile_stream_loadd(&PixelTile, PixelData.data() + i, 4);
// 8-bit dot-product rows of A and columns of B into matrix C of 32-bit
// sums
__tile_dpbuud(&SumTile, MaskTile, PixelTile);
}
// Store vector of sums
std::array<std::uint32_t, 4> SumData;
__tile_stored(SumData.data(), 4, SumTile);
// Print
for( std::size_t i = 0; i < 4; ++i )
{
std::printf("%08X ", SumData[i] / std::uint32_t(PixelData.size()));
}
// 000000DD 000000CC 000000BB 000000AA
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment