@CoffeeVampir3
Created October 24, 2025 00:58
AMX CLAUDE.md

AMX (Advanced Matrix Extensions) Programming Guide

Overview

AMX is an x86-64 instruction set extension that accelerates matrix operations using two-dimensional tile registers and a dedicated tile matrix-multiply unit. It's designed for high-throughput matrix multiplication in AI/ML workloads.

Key Features:

  • 8 tile registers (TMM0-TMM7)
  • Each tile: up to 16 rows × 64 bytes
  • Native support for INT8 and BF16 operations; FP16 requires the separate AMX-FP16 extension found on newer CPUs
  • Higher peak throughput than AVX-512 for dense matrix multiplication

Compile flags:

clang -mamx-tile -mamx-int8 -mamx-bf16 -march=native

Critical Setup: OS Permissions

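MOST IMPORTANT: Before using any AMX instructions on Linux, you MUST request permission from the kernel. Without this, the kernel kills your program with a signal on the first AMX instruction (typically SIGILL rather than a segfault).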

#include <sys/syscall.h>
#include <unistd.h>
#include <cstdio>    // fprintf

#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA 18

bool enable_amx() {
    long rc = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
    return rc == 0;
}

int main() {
    if (!enable_amx()) {
        fprintf(stderr, "Failed to enable AMX\n");
        return 1;
    }
    // Now safe to use AMX instructions
}

Why this is needed: AMX tile data is a large piece of extended CPU state (up to 8 KiB) that must be saved and restored on context switches. The kernel only sets up that state for processes that explicitly request it.
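A minimal sketch of requesting the permission and then confirming it was actually granted. It assumes the Linux arch_prctl query code ARCH_GET_XCOMP_PERM (0x1022), which is not used elsewhere in this guide; the helper names amx_permission_granted and request_amx_permission are hypothetical. On anything other than Linux x86-64 the helpers simply report failure.

```cpp
#include <cassert>
#include <cerrno>
#include <cstdint>
#if defined(__linux__) && defined(__x86_64__)
#include <sys/syscall.h>
#include <unistd.h>
#endif

// Codes from the Linux x86 arch_prctl interface.
constexpr long kArchGetXcompPerm = 0x1022;  // ARCH_GET_XCOMP_PERM (assumed here)
constexpr long kArchReqXcompPerm = 0x1023;  // ARCH_REQ_XCOMP_PERM
constexpr long kXfeatureXtiledata = 18;     // XFEATURE_XTILEDATA

// Hypothetical helper: true if this process's permission mask includes tile data.
bool amx_permission_granted() {
#if defined(__linux__) && defined(__x86_64__)
    std::uint64_t mask = 0;
    if (syscall(SYS_arch_prctl, kArchGetXcompPerm, &mask) != 0) return false;
    return ((mask >> kXfeatureXtiledata) & 1u) != 0;
#else
    return false;
#endif
}

// Hypothetical helper: request permission, then confirm it.
// On failure, *err receives the errno reported by the kernel.
bool request_amx_permission(int* err) {
    assert(err != nullptr);
    *err = 0;
#if defined(__linux__) && defined(__x86_64__)
    if (syscall(SYS_arch_prctl, kArchReqXcompPerm, kXfeatureXtiledata) != 0) {
        *err = errno;
        return false;
    }
    return amx_permission_granted();
#else
    *err = ENOSYS;
    return false;
#endif
}
```

Reporting the errno makes it easier to tell an old kernel or unsupported CPU apart from a policy refusal, instead of printing a generic failure message.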

Tile Configuration

Before using tiles, configure them with _tile_loadconfig():

struct TileConfig {
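    uint8_t palette_id;      // Must be 1 for AMX (0 returns tiles to the initial state)
    uint8_t start_row;       // Restart row for interrupted loads/stores; set to 0
    uint8_t reserved[14];    // Must be zero
    uint16_t colsb[16];      // Bytes per row for each tile (entries 8-15 must be 0)
    uint8_t rows[16];        // Row count for each tile (entries 8-15 must be 0)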
};

TileConfig cfg{};
cfg.palette_id = 1;

// Configure tile 0: 16 rows × 64 bytes
cfg.rows[0] = 16;
cfg.colsb[0] = 64;

_tile_loadconfig(&cfg);

Important notes:

  • colsb is in bytes, not elements
  • For INT8: 64 bytes = 64 elements
  • For INT32 output: 64 bytes = 16 elements
  • For BF16: 64 bytes = 32 elements
  • Always call _tile_release() when done
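Because colsb is in bytes, it can help to describe tiles in elements and let a helper do the conversion. The sketch below repeats the TileConfig layout with compile-time checks on its size and field offsets; set_tile is a hypothetical helper, not part of any AMX header.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

struct TileConfig {
    std::uint8_t palette_id;
    std::uint8_t start_row;
    std::uint8_t reserved[14];
    std::uint16_t colsb[16];
    std::uint8_t rows[16];
};

// The configuration block is exactly 64 bytes with fixed field offsets.
static_assert(sizeof(TileConfig) == 64, "tile config must be 64 bytes");
static_assert(offsetof(TileConfig, colsb) == 16, "colsb starts at byte 16");
static_assert(offsetof(TileConfig, rows) == 48, "rows starts at byte 48");

// Hypothetical helper: configure one tile from an element count and element size.
inline void set_tile(TileConfig& cfg, int tile, int rows, int cols,
                     std::size_t elem_size) {
    assert(tile >= 0 && tile < 8);      // only TMM0-TMM7 exist
    assert(rows >= 1 && rows <= 16);    // at most 16 rows per tile
    const std::size_t bytes = static_cast<std::size_t>(cols) * elem_size;
    assert(bytes >= 1 && bytes <= 64);  // at most 64 bytes per row
    cfg.rows[tile] = static_cast<std::uint8_t>(rows);
    cfg.colsb[tile] = static_cast<std::uint16_t>(bytes);
}
```

For example, set_tile(cfg, 2, 16, 16, sizeof(int32_t)) describes a 16×16 INT32 accumulator and stores 64 in colsb[2].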

Memory Alignment

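Recommended: Keep memory used with AMX 64-byte aligned. Tile loads and stores accept unaligned addresses, but aligned rows avoid cache-line splits.

// Aligned allocation (size must be a multiple of the alignment)
void* ptr = std::aligned_alloc(64, size);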

// Or with unique_ptr
auto data = std::unique_ptr<int8_t[], decltype(&std::free)>(
    static_cast<int8_t*>(std::aligned_alloc(64, size)),
    &std::free
);

Why: A full tile row is 64 bytes, exactly one cache line. A misaligned row straddles two cache lines, which costs throughput; it does not by itself cause a fault.
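std::aligned_alloc expects the size to be a multiple of the alignment, which is easy to miss for odd-sized matrices. The following sketch rounds the size up before allocating; round_up, alloc_tile_buffer, FreeDeleter, and is_aligned are hypothetical names introduced for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>

constexpr std::size_t kTileAlign = 64;

// Round n up to the next multiple of `multiple`.
constexpr std::size_t round_up(std::size_t n, std::size_t multiple) {
    return (n + multiple - 1) / multiple * multiple;
}

// Deleter so unique_ptr releases the buffer with std::free.
struct FreeDeleter {
    void operator()(void* p) const noexcept { std::free(p); }
};

// Allocate `count` elements of a trivial type T in 64-byte aligned storage.
// The memory is uninitialized; the result is null if allocation fails.
template <typename T>
std::unique_ptr<T[], FreeDeleter> alloc_tile_buffer(std::size_t count) {
    assert(count > 0);
    const std::size_t bytes = round_up(count * sizeof(T), kTileAlign);
    return std::unique_ptr<T[], FreeDeleter>(
        static_cast<T*>(std::aligned_alloc(kTileAlign, bytes)));
}

// True if p is a multiple of `alignment`.
inline bool is_aligned(const void* p, std::size_t alignment) {
    return reinterpret_cast<std::uintptr_t>(p) % alignment == 0;
}
```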

Core Instructions

Tile Load/Store

// Load tile from memory
// tile_index must be a compile-time constant (0-7)
// stride is the row stride in BYTES
_tile_loadd(tile_index, ptr, stride);

// Store tile to memory
_tile_stored(tile_index, ptr, stride);

// Zero a tile
_tile_zero(tile_index);

Stride calculation:

  • For row-major matrix with N columns of INT8: stride = N
  • For row-major matrix with N columns of INT32: stride = N * 4
  • Stride must account for element size in bytes
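The stride rules above can be sketched as two small helpers: one that converts a column count to a byte stride, and one that finds the start of a sub-block inside a larger row-major matrix. Both names (row_stride_bytes, block_ptr) are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Row stride in bytes for a row-major matrix with `cols` elements of type T per row.
template <typename T>
constexpr std::size_t row_stride_bytes(std::size_t cols) {
    return cols * sizeof(T);
}

// Pointer to element (row, col) of a row-major matrix with `cols` columns.
// Passing this as the base pointer selects a sub-block of the matrix.
template <typename T>
const T* block_ptr(const T* base, std::size_t cols, std::size_t row,
                   std::size_t col) {
    assert(col < cols);
    return base + row * cols + col;
}
```

A load of the A block at row i and column k of an M×K INT8 matrix would then pass block_ptr(a, K, i, k) as the base and row_stride_bytes<int8_t>(K) as the stride: the stride is always that of the full matrix, not of the tile.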

INT8 Matrix Multiplication

// C += A × B (all INT8 input, INT32 accumulator)
_tile_dpbssd(dst, src1, src2);  // signed × signed
_tile_dpbsud(dst, src1, src2);  // signed × unsigned
_tile_dpbusd(dst, src1, src2);  // unsigned × signed
_tile_dpbuud(dst, src1, src2);  // unsigned × unsigned

Matrix dimensions:

  • A (src1): M rows × K columns (K must be multiple of 4)
  • B (src2): K/4 rows × N×4 columns (VNNI layout)
  • C (dst): M rows × N columns (INT32)

VNNI Layout: B must be re-packed so that each group of 4 consecutive K values for one output column is contiguous. For a row-major B with N = 4 columns:

Normal: [b0, b1, b2, b3, ...]
VNNI:   [b0, b4, b8, b12, b1, b5, b9, b13, ...]
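A minimal sketch of that re-packing for INT8, assuming B starts as a row-major K×N matrix. The function name pack_b_vnni_int8 is hypothetical; element (k, n) moves to row k/4, column n*4 + k%4 of the packed matrix.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Re-pack a row-major K×N INT8 matrix into the (K/4) × (N*4) VNNI layout.
std::vector<std::int8_t> pack_b_vnni_int8(const std::vector<std::int8_t>& b,
                                          std::size_t K, std::size_t N) {
    assert(K % 4 == 0);
    assert(b.size() == K * N);
    std::vector<std::int8_t> packed(K * N);
    for (std::size_t k = 0; k < K; ++k) {
        for (std::size_t n = 0; n < N; ++n) {
            // Four consecutive k values of column n become adjacent bytes.
            packed[(k / 4) * (N * 4) + n * 4 + (k % 4)] = b[k * N + n];
        }
    }
    return packed;
}
```

With K = 4, N = 4 and B filled with 0..15, the packed output begins 0, 4, 8, 12, 1, 5, 9, 13, matching the layout shown above.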

BF16 Matrix Multiplication

// C += A × B (BF16 input, FP32 accumulator)
_tile_dpbf16ps(dst, src1, src2);

Matrix dimensions:

  • A (src1): M rows × K columns (K must be multiple of 2)
  • B (src2): K/2 rows × N×2 columns (VNNI layout)
  • C (dst): M rows × N columns (FP32)
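For BF16 the inputs usually start as FP32, so preparing B involves a conversion as well as pairing. The sketch below assumes round-to-nearest-even conversion and a row-major K×N FP32 source; float_to_bf16 and pack_b_vnni_bf16 are hypothetical helper names, and BF16 values are held as raw uint16_t bit patterns.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// FP32 -> BF16 bit pattern with round-to-nearest-even.
std::uint16_t float_to_bf16(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        // NaN: keep the sign and force the quiet bit so it stays a NaN.
        return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
    }
    const std::uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<std::uint16_t>((bits + rounding) >> 16);
}

// Convert a row-major K×N FP32 matrix to BF16 in the (K/2) × (N*2) VNNI layout.
std::vector<std::uint16_t> pack_b_vnni_bf16(const std::vector<float>& b,
                                            std::size_t K, std::size_t N) {
    assert(K % 2 == 0);
    assert(b.size() == K * N);
    std::vector<std::uint16_t> packed(K * N);
    for (std::size_t k = 0; k < K; ++k) {
        for (std::size_t n = 0; n < N; ++n) {
            // Two consecutive k values of column n become adjacent elements.
            packed[(k / 2) * (N * 2) + n * 2 + (k % 2)] = float_to_bf16(b[k * N + n]);
        }
    }
    return packed;
}
```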

Complete Working Example

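#include <immintrin.h>
#include <sys/syscall.h>
#include <unistd.h>
#include <cstdint>   // int8_t, int32_t, uint8_t, uint16_t
#include <cstdlib>   // std::aligned_alloc, std::free
#include <cstring>
#include <iostream>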

#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA 18

struct TileConfig {
    uint8_t palette_id;
    uint8_t start_row;
    uint8_t reserved[14];
    uint16_t colsb[16];
    uint8_t rows[16];
};

bool enable_amx() {
    return syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) == 0;
}

void matmul_16x16x64_int8() {
    const int M = 16, N = 16, K = 64;
    
    // Allocate aligned memory
    auto a = static_cast<int8_t*>(std::aligned_alloc(64, M * K));
    auto b = static_cast<int8_t*>(std::aligned_alloc(64, K * N));
    auto c = static_cast<int32_t*>(std::aligned_alloc(64, M * N * 4));
    
    // Initialize
    std::memset(a, 1, M * K);
    std::memset(b, 2, K * N);
    std::memset(c, 0, M * N * 4);
    
    // Configure tiles
    TileConfig cfg{};
    cfg.palette_id = 1;
    cfg.rows[0] = M;     cfg.colsb[0] = K;           // Tile 0: A
    cfg.rows[1] = K/4;   cfg.colsb[1] = N * 4;       // Tile 1: B (VNNI)
    cfg.rows[2] = M;     cfg.colsb[2] = N * 4;       // Tile 2: C
    
    _tile_loadconfig(&cfg);
    
    // Compute: C = A × B
    _tile_zero(2);
    _tile_loadd(0, a, K);
    _tile_loadd(1, b, N * 4);
    _tile_dpbssd(2, 0, 1);
    _tile_stored(2, c, N * 4);
    
    _tile_release();
    
    std::cout << "Result[0] = " << c[0] << "\n";
    
    std::free(a);
    std::free(b);
    std::free(c);
}

int main() {
    if (!enable_amx()) {
        std::cerr << "AMX not available\n";
        return 1;
    }
    matmul_16x16x64_int8();
    return 0;
}

Common Pitfalls

1. Forgetting to Request Permission

Problem: The process is killed by a signal (typically SIGILL on Linux) on the first AMX instruction
Solution: Always call the syscall before any AMX operations

2. Incorrect Memory Alignment

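Problem: Rows that straddle cache lines reduce throughput; std::aligned_alloc can also fail when size is not a multiple of the alignment
Solution: Use std::aligned_alloc(64, size) with size rounded up to a multiple of 64, or posix_memalign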

3. Wrong Stride Values

Problem: Incorrect results or memory corruption
Solution: Stride is in BYTES, not elements. For INT32 matrix with N columns, stride = N * 4

4. Forgetting VNNI Layout for B Matrix

Problem: Incorrect results
Solution: Matrix B must be re-packed so that each group of 4 (INT8) or 2 (BF16) consecutive K values per column is contiguous

5. Not Releasing Tiles

Problem: Tile state stays live, so up to 8 KiB of tile data keeps being saved and restored on context switches
Solution: Call _tile_release() when done, including on error paths

6. Incorrect Tile Dimensions

Problem: Faults (an invalid configuration in _tile_loadconfig, or mismatched tile shapes in a multiply) or wrong results
Solution:

  • Max 16 rows per tile
  • Max 64 bytes per row
  • K dimension must be multiple of VNNI_BLK (4 for INT8, 2 for BF16)
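These limits can be checked in plain code before a configuration is ever loaded. The sketch below expresses the shape rules in bytes, which makes the INT8 and BF16 cases identical (a group of 4 INT8 values and a pair of BF16 values are both 4 bytes). TileShape, valid_tile, and valid_dp_shapes are hypothetical names, and the consistency rules reflect my reading of the instruction definitions rather than anything stated in this guide.

```cpp
#include <cassert>

// Shape of one configured tile: row count and bytes per row.
struct TileShape {
    int rows;
    int colsb;
};

// A usable tile has 1-16 rows and 1-64 bytes per row.
inline bool valid_tile(TileShape t) {
    return t.rows >= 1 && t.rows <= 16 && t.colsb >= 1 && t.colsb <= 64;
}

// Shape rules for C += A × B with B in VNNI layout (INT8 or BF16 inputs).
inline bool valid_dp_shapes(TileShape a, TileShape b, TileShape c) {
    if (!valid_tile(a) || !valid_tile(b) || !valid_tile(c)) return false;
    // Every row holds whole 4-byte groups (inputs) or 4-byte outputs.
    if (a.colsb % 4 != 0 || b.colsb % 4 != 0 || c.colsb % 4 != 0) return false;
    if (c.rows != a.rows) return false;       // M matches
    if (c.colsb != b.colsb) return false;     // N matches (N*4 bytes on both)
    if (a.colsb / 4 != b.rows) return false;  // K matches (one B row per 4 bytes of A)
    return true;
}
```

For the 16×16×64 INT8 example, all three tiles are {16, 64} and the check passes; shrinking B to 8 rows without changing A makes it fail.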

Performance Tips

  1. Minimize tile loads/stores: Keep data in tiles across multiple operations
  2. Proper blocking: Process matrices in tile-sized chunks (16×64 INT8 or 16×32 BF16 input blocks, 16×16 output blocks)
  3. Memory layout: Ensure both input matrices are aligned and have cache-friendly strides
  4. Zero tiles efficiently: Use _tile_zero() instead of loading zeros from memory
  5. Batch operations: Process multiple tiles before storing results

Tile Register Layout Strategy

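Typical allocation for matrix multiplication with 2×2 blocking:

  • Tiles 0-1: Left matrix (A) blocks
  • Tiles 2-3: Right matrix (B) blocks
  • Tiles 4-7: Accumulator (C) blocks

Example for 48×32 output (3×2 blocks of 16×16). The six accumulators leave only two free tiles, so one A block and one B block are loaded at a time: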

   B0  B1
A0 C0  C1
A1 C2  C3
A2 C4  C5
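The loop structure behind this kind of blocking can be sketched in scalar code. In the sketch below, tile_mac stands in for one load, load, multiply-accumulate step on a 16×16 output block with a K step of 64; it is a plain C++ stand-in, not AMX code, and B is kept in ordinary row-major order for readability. All names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::size_t TM = 16, TN = 16, TK = 64;  // one INT8 tile step

// Scalar stand-in for one tile multiply-accumulate: C_blk += A_blk × B_blk.
void tile_mac(const std::int8_t* a, std::size_t lda,
              const std::int8_t* b, std::size_t ldb,
              std::int32_t* c, std::size_t ldc,
              std::size_t m, std::size_t n, std::size_t k) {
    for (std::size_t i = 0; i < m; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            std::int32_t acc = c[i * ldc + j];
            for (std::size_t p = 0; p < k; ++p) {
                acc += std::int32_t(a[i * lda + p]) * std::int32_t(b[p * ldb + j]);
            }
            c[i * ldc + j] = acc;
        }
    }
}

// C = A × B, walking the output in 16×16 blocks and K in steps of 64.
void blocked_matmul_int8(const std::vector<std::int8_t>& A,
                         const std::vector<std::int8_t>& B,
                         std::vector<std::int32_t>& C,
                         std::size_t M, std::size_t N, std::size_t K) {
    assert(A.size() == M * K && B.size() == K * N && C.size() == M * N);
    std::fill(C.begin(), C.end(), 0);
    for (std::size_t i0 = 0; i0 < M; i0 += TM) {
        for (std::size_t j0 = 0; j0 < N; j0 += TN) {
            // Each output block is revisited for every K step; on AMX it
            // would stay in an accumulator tile for the whole inner loop.
            for (std::size_t k0 = 0; k0 < K; k0 += TK) {
                tile_mac(&A[i0 * K + k0], K, &B[k0 * N + j0], N,
                         &C[i0 * N + j0], N,
                         std::min(TM, M - i0), std::min(TN, N - j0),
                         std::min(TK, K - k0));
            }
        }
    }
}
```

Keeping the K loop innermost is what lets an accumulator tile be zeroed once and stored once per output block.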

Summary Checklist

Before running AMX code:

  • Request OS permission via syscall
  • Prefer 64-byte aligned buffers
  • Configure tiles with _tile_loadconfig()
  • Use correct stride values (in bytes)
  • Handle VNNI layout for matrix B
  • Call _tile_release() when done
  • Compile with -mamx-tile -mamx-int8 -mamx-bf16

AMX Instruction Reference

Configuration & Control

_tile_loadconfig(const void* config)

Load tile configuration from 64-byte memory structure. Must be called before using tiles.

  • config: Pointer to TileConfig struct (palette_id, rows[], colsb[])
  • Alignment: None required; 64-byte alignment is a common convention
  • Note: Zeroes all tile data

_tile_release()

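Release all tile registers and return the tile state to its initial, unconfigured state. Call when finished with AMX operations.

  • Note: Not required for correctness at process exit; the OS handles context switches either way
  • Performance: Cheap operation; with tiles back in their initial state, the OS can skip saving tile data on context switches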

_tile_storeconfig(void* config)

Store current tile configuration to memory.

Data Movement

_tile_loadd(int tile, const void* base, size_t stride)

Load tile from memory with specified stride.

  • tile: Tile register index (0-7), must be a compile-time constant
  • base: Source pointer (no alignment requirement; 64-byte alignment recommended)
  • stride: Row stride in bytes (not elements)
  • Loads: The configured number of rows, colsb bytes each; the whole region must be readable

_tile_stored(int tile, void* base, size_t stride)

Store tile to memory with specified stride.

  • tile: Tile register index (0-7), must be a compile-time constant
  • base: Destination pointer (no alignment requirement; 64-byte alignment recommended)
  • stride: Row stride in bytes
  • Stores: Full configured tile dimensions

_tile_zero(int tile)

Zero all bytes in a tile register. Fast initialization.

  • tile: Tile register index (0-7), must be a compile-time constant
  • Performance: Much faster than loading zeros from memory

_tile_stream_loadd(int tile, const void* base, size_t stride)

Load with a non-temporal hint, which tells the CPU the data is unlikely to be reused soon so it can limit cache pollution. Use for large, single-use data.

Integer Matrix Operations (INT8)

_tile_dpbssd(int dst, int src1, int src2)

Matrix multiply-accumulate: dst += src1 × src2 (signed × signed → INT32)

  • src1: M×K INT8 matrix (K must be multiple of 4)
  • src2: (K/4)×(N×4) INT8 matrix in VNNI format
  • dst: M×N INT32 accumulator
  • Operation: Each output element accumulates K/4 four-element dot products (K multiply-adds in total)
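A scalar reference that follows this description can be handy for checking AMX results on small inputs. The sketch below assumes row-major A and C and a VNNI-packed B as described above; ref_dpbssd is a hypothetical name and the code makes no attempt to model hardware overflow behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar reference for dst += src1 × src2 with src2 in VNNI layout.
// a: M×K INT8 (row-major), b_vnni: (K/4)×(N*4) INT8, c: M×N INT32 (row-major).
void ref_dpbssd(std::vector<std::int32_t>& c,
                const std::vector<std::int8_t>& a,
                const std::vector<std::int8_t>& b_vnni,
                std::size_t M, std::size_t N, std::size_t K) {
    assert(K % 4 == 0);
    assert(a.size() == M * K && b_vnni.size() == K * N && c.size() == M * N);
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            for (std::size_t kg = 0; kg < K / 4; ++kg) {
                // One four-element dot product per group of 4 K values.
                for (std::size_t i = 0; i < 4; ++i) {
                    c[m * N + n] += std::int32_t(a[m * K + kg * 4 + i]) *
                                    std::int32_t(b_vnni[kg * (N * 4) + n * 4 + i]);
                }
            }
        }
    }
}
```

With A filled with 1, B filled with 2, and M = N = 16, K = 64, every output element is 128, which is the value the complete example in this guide should print for Result[0].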

_tile_dpbsud(int dst, int src1, int src2)

Matrix multiply-accumulate: dst += src1 × src2 (signed × unsigned → INT32)

  • src1: Signed INT8
  • src2: Unsigned INT8 (VNNI format)
  • Use case: Activations (signed) × weights (unsigned)

_tile_dpbusd(int dst, int src1, int src2)

Matrix multiply-accumulate: dst += src1 × src2 (unsigned × signed → INT32)

  • src1: Unsigned INT8
  • src2: Signed INT8 (VNNI format)
  • Use case: Activations (unsigned) × weights (signed)

_tile_dpbuud(int dst, int src1, int src2)

Matrix multiply-accumulate: dst += src1 × src2 (unsigned × unsigned → INT32)

  • src1: Unsigned INT8
  • src2: Unsigned INT8 (VNNI format)
  • Use case: Both inputs non-negative

BF16 Matrix Operations

_tile_dpbf16ps(int dst, int src1, int src2)

Matrix multiply-accumulate: dst += src1 × src2 (BF16 × BF16 → FP32)

  • src1: M×K BF16 matrix (K must be multiple of 2)
  • src2: (K/2)×(N×2) BF16 matrix in VNNI format
  • dst: M×N FP32 accumulator
  • Precision: BF16 inputs carry 8 significant bits (about 2-3 decimal digits) with the FP32 exponent range; products are accumulated in FP32

Key Constraints

| Instruction | Max Rows | Max Cols (bytes) | K Constraint | Notes |
|---|---|---|---|---|
| dpbssd/dpbsud/dpbusd/dpbuud | 16 | 64 | K % 4 == 0 | INT32 output: 16 elements/row |
| dpbf16ps | 16 | 64 | K % 2 == 0 | FP32 output: 16 elements/row |
| dpfp16ps | 16 | 64 | K % 2 == 0 | FP32 output: 16 elements/row; requires AMX-FP16 |
| loadd/stored | 16 | 64 | - | No alignment requirement; 64-byte alignment recommended |

Tile Register Allocation

8 tile registers available (TMM0-TMM7):

  • Tiles are 2D registers, dimensions set by configuration
  • Only 8 tiles exist, so accumulator blocking has to be planned around them
  • Common pattern: 2 tiles for A, 2 for B, 4 for C accumulators

Typical Operation Sequence

1. _tile_loadconfig()      // Set dimensions
2. _tile_zero(dst)         // Clear accumulator
3. _tile_loadd(src1, ...)  // Load A
4. _tile_loadd(src2, ...)  // Load B
5. _tile_dpbssd(dst, ...)  // Compute C += A×B
6. _tile_stored(dst, ...)  // Write result
7. _tile_release()         // Clean up