AMX (Advanced Matrix Extensions) is an x86-64 ISA extension that provides hardware-accelerated matrix operations using dedicated two-dimensional tile registers. It is designed for high-performance matrix multiplication in AI/ML workloads.
Key Features:
- 8 tile registers (TMM0-TMM7)
- Each tile: up to 16 rows × 64 bytes
- Native support for INT8, BF16, and FP16 operations
- Significantly faster than AVX-512 for matrix operations
Compile flags:
```shell
clang -mamx-tile -mamx-int8 -mamx-bf16 -march=native
```

MOST IMPORTANT: Before using any AMX instructions on Linux, you MUST request permission from the OS. Without this, your program will segfault on the first tile instruction.
```cpp
#include <cstdio>
#include <sys/syscall.h>
#include <unistd.h>

#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA  18

bool enable_amx() {
    // Ask the kernel to enable the XTILEDATA state component for this process.
    long rc = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
    return rc == 0;
}

int main() {
    if (!enable_amx()) {
        fprintf(stderr, "Failed to enable AMX\n");
        return 1;
    }
    // Now safe to use AMX instructions
}
```

Why this is needed: AMX uses a large extended CPU state (8 KiB of tile data) that must be saved and restored on context switches. The kernel requires explicit opt-in before it will enable this state for your process.
Before using tiles, configure them with _tile_loadconfig():
```cpp
struct alignas(64) TileConfig {
    uint8_t  palette_id;   // Must be 1 for AMX
    uint8_t  start_row;    // Used for fault restart; set to 0
    uint8_t  reserved[14]; // Must be zero
    uint16_t colsb[16];    // Bytes per row for each tile
    uint8_t  rows[16];     // Row count for each tile
};

TileConfig cfg{};
cfg.palette_id = 1;
// Configure tile 0: 16 rows × 64 bytes
cfg.rows[0]  = 16;
cfg.colsb[0] = 64;
_tile_loadconfig(&cfg);
```

Important notes:
- colsb is in bytes, not elements
- For INT8: 64 bytes = 64 elements
- For INT32 output: 64 bytes = 16 elements
- For BF16: 64 bytes = 32 elements
- Always call _tile_release() when done
Critical: Memory used with AMX should be 64-byte aligned. The TileConfig passed to _tile_loadconfig() must be; for tile data, alignment keeps each row within whole cache lines.
```cpp
// Correct way to allocate (size should be a multiple of 64 for aligned_alloc)
void* ptr = std::aligned_alloc(64, size);

// Or with unique_ptr
auto data = std::unique_ptr<int8_t[], decltype(&std::free)>(
    static_cast<int8_t*>(std::aligned_alloc(64, size)),
    &std::free
);
```

Why: AMX loads and stores move whole 64-byte rows; aligned buffers keep each row within cache lines. Misaligned or undersized buffers lead to segfaults or undefined behavior.
```cpp
// Load tile from memory
// stride is the row stride in BYTES
_tile_loadd(tile_index, ptr, stride);

// Store tile to memory
_tile_stored(tile_index, ptr, stride);

// Zero a tile
_tile_zero(tile_index);
```

Stride calculation:
- For a row-major matrix with N columns of INT8: stride = N
- For a row-major matrix with N columns of INT32: stride = N * 4
- Stride must account for element size in bytes
```cpp
// C += A × B (all INT8 input, INT32 accumulator)
_tile_dpbssd(dst, src1, src2); // signed × signed
_tile_dpbsud(dst, src1, src2); // signed × unsigned
_tile_dpbusd(dst, src1, src2); // unsigned × signed
_tile_dpbuud(dst, src1, src2); // unsigned × unsigned
```

Matrix dimensions:
- A (src1): M rows × K columns (K must be multiple of 4)
- B (src2): K/4 rows × N×4 columns (VNNI layout)
- C (dst): M rows × N columns (INT32)
VNNI Layout: B must be repacked so that groups of 4 consecutive rows are interleaved per column. For example, with N = 4 columns (b0..b3 = row 0, b4..b7 = row 1, ...):
Normal (row-major): [b0, b1, b2, b3, b4, b5, ...]
VNNI: [b0, b4, b8, b12, b1, b5, b9, b13, ...]
```cpp
// C += A × B (BF16 input, FP32 accumulator)
_tile_dpbf16ps(dst, src1, src2);
```

Matrix dimensions:
- A (src1): M rows × K columns (K must be multiple of 2)
- B (src2): K/2 rows × N×2 columns (VNNI layout)
- C (dst): M rows × N columns (FP32)
```cpp
#include <immintrin.h>
#include <sys/syscall.h>
#include <unistd.h>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>

#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA  18

struct alignas(64) TileConfig {
    uint8_t  palette_id;
    uint8_t  start_row;
    uint8_t  reserved[14];
    uint16_t colsb[16];
    uint8_t  rows[16];
};

bool enable_amx() {
    return syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) == 0;
}

void matmul_16x16x64_int8() {
    const int M = 16, N = 16, K = 64;

    // Allocate aligned memory
    auto a = static_cast<int8_t*>(std::aligned_alloc(64, M * K));
    auto b = static_cast<int8_t*>(std::aligned_alloc(64, K * N));
    auto c = static_cast<int32_t*>(std::aligned_alloc(64, M * N * 4));

    // Initialize (a constant matrix is its own VNNI packing)
    std::memset(a, 1, M * K);
    std::memset(b, 2, K * N);
    std::memset(c, 0, M * N * 4);

    // Configure tiles
    TileConfig cfg{};
    cfg.palette_id = 1;
    cfg.rows[0] = M;     cfg.colsb[0] = K;     // Tile 0: A (16 × 64 bytes)
    cfg.rows[1] = K / 4; cfg.colsb[1] = N * 4; // Tile 1: B (VNNI, 16 × 64 bytes)
    cfg.rows[2] = M;     cfg.colsb[2] = N * 4; // Tile 2: C (16 × 64 bytes)
    _tile_loadconfig(&cfg);

    // Compute: C = A × B
    _tile_zero(2);
    _tile_loadd(0, a, K);      // stride = K bytes (INT8 row)
    _tile_loadd(1, b, N * 4);  // stride = N*4 bytes (VNNI row)
    _tile_dpbssd(2, 0, 1);
    _tile_stored(2, c, N * 4); // stride = N*4 bytes (INT32 row)
    _tile_release();

    std::cout << "Result[0] = " << c[0] << "\n"; // 64 × (1 × 2) = 128

    std::free(a);
    std::free(b);
    std::free(c);
}

int main() {
    if (!enable_amx()) {
        std::cerr << "AMX not available\n";
        return 1;
    }
    matmul_16x16x64_int8();
    return 0;
}
```

Problem: Immediate segfault on first AMX instruction
Solution: Always call the syscall before any AMX operations
Problem: Segfaults or corrupted data
Solution: Use std::aligned_alloc(64, size) or posix_memalign
Problem: Incorrect results or memory corruption
Solution: Stride is in BYTES, not elements. For INT32 matrix with N columns, stride = N * 4
Problem: Incorrect results
Solution: Matrix B must be transposed and packed in groups of 4 (INT8) or 2 (BF16)
Problem: Resource leaks, potential crashes
Solution: Always call _tile_release() when done, even on error paths
Problem: Segfaults or wrong results
Solution:
- Max 16 rows per tile
- Max 64 bytes per row
- K dimension must be multiple of VNNI_BLK (4 for INT8, 2 for BF16)
- Minimize tile loads/stores: Keep data in tiles across multiple operations
- Proper blocking: Process matrices in tile-sized chunks (16×16 or 16×32)
- Memory layout: Ensure both input matrices are aligned and have cache-friendly strides
- Zero tiles efficiently: Use _tile_zero() instead of loading zeros from memory
- Batch operations: Process multiple tiles before storing results
Typical allocation for matrix multiplication:
- Tiles 0-5: Accumulator (C) blocks, kept resident across the K loop
- Tile 6: Current left-matrix (A) block, reloaded per step
- Tile 7: Current right-matrix (B) block, reloaded per step
Example for 48×32 output (three 16-row A blocks × two 16-column B blocks):
       B0  B1
  A0   C0  C1
  A1   C2  C3
  A2   C4  C5
Before running AMX code:
- Request OS permission via syscall
- Allocate 64-byte aligned memory
- Configure tiles with _tile_loadconfig()
- Use correct stride values (in bytes)
- Handle VNNI layout for matrix B
- Call _tile_release() when done
- Compile with -mamx-tile -mamx-int8 -mamx-bf16
_tile_loadconfig(config): Load tile configuration from a 64-byte memory structure. Must be called before using tiles.
- config: Pointer to TileConfig struct (palette_id, rows[], colsb[])
- Alignment: Must be 64-byte aligned
- Note: Invalidates all tile data
_tile_release(): Release all tile registers and return to the initial tile state. Call when finished with AMX operations.
- Note: Releasing tiles lets the kernel skip saving/restoring the large tile state on context switches
- Performance: Cheap operation
_tile_storeconfig(config): Store the current tile configuration to memory.
_tile_loadd(tile, base, stride): Load a tile from memory with the specified stride.
- tile: Tile register index (0-7)
- base: Source pointer (should be 64-byte aligned)
- stride: Row stride in bytes (not elements)
- Loads: The configured rows × colsb bytes for that tile
_tile_stored(tile, base, stride): Store a tile to memory with the specified stride.
- tile: Tile register index (0-7)
- base: Destination pointer (should be 64-byte aligned)
- stride: Row stride in bytes
- Stores: Full configured tile dimensions
_tile_zero(tile): Zero all bytes in a tile register. Fast initialization.
- tile: Tile register index (0-7)
- Performance: Much faster than loading zeros from memory
_tile_stream_loadd(tile, base, stride): Non-temporal load that minimizes cache pollution. Use for large, single-use data.
_tile_dpbssd(dst, src1, src2): Matrix multiply-accumulate: dst += src1 × src2 (signed × signed → INT32)
- src1: M×K INT8 matrix (K must be multiple of 4)
- src2: (K/4)×(N×4) INT8 matrix in VNNI format
- dst: M×N INT32 accumulator
- Operation: Each output element accumulates a length-K dot product, computed as K/4 groups of four INT8 products
_tile_dpbsud(dst, src1, src2): Matrix multiply-accumulate: dst += src1 × src2 (signed × unsigned → INT32)
- src1: Signed INT8
- src2: Unsigned INT8 (VNNI format)
- Use case: Activations (signed) × weights (unsigned)
_tile_dpbusd(dst, src1, src2): Matrix multiply-accumulate: dst += src1 × src2 (unsigned × signed → INT32)
- src1: Unsigned INT8
- src2: Signed INT8 (VNNI format)
- Use case: Activations (unsigned) × weights (signed)
_tile_dpbuud(dst, src1, src2): Matrix multiply-accumulate: dst += src1 × src2 (unsigned × unsigned → INT32)
- src1: Unsigned INT8
- src2: Unsigned INT8 (VNNI format)
- Use case: Both inputs non-negative
_tile_dpbf16ps(dst, src1, src2): Matrix multiply-accumulate: dst += src1 × src2 (BF16 × BF16 → FP32)
- src1: M×K BF16 matrix (K must be multiple of 2)
- src2: (K/2)×(N×2) BF16 matrix in VNNI format
- dst: M×N FP32 accumulator
- Precision: BF16 inputs carry only ~2-3 decimal digits (8-bit significand), but accumulation is in full FP32, which is sufficient for most ML workloads
| Instruction | Max Rows | Max Cols (bytes) | K Constraint | Notes |
|---|---|---|---|---|
| dpbssd/dpbsud/dpbusd/dpbuud | 16 | 64 | K % 4 == 0 | INT32 output: 16 elements/row |
| dpbf16ps | 16 | 64 | K % 2 == 0 | FP32 output: 16 elements/row |
| dpfp16ps | 16 | 64 | K % 2 == 0 | FP32 output: 16 elements/row |
| loadd/stored | 16 | 64 | - | Alignment: 64 bytes required |
8 tile registers available (TMM0-TMM7):
- Tiles are 2D registers, dimensions set by configuration
- No register pressure in typical usage
- Common pattern: 2-3 tiles for A, 1-2 for B, 3-4 for C accumulators
```cpp
1. _tile_loadconfig()     // Set dimensions
2. _tile_zero(dst)        // Clear accumulator
3. _tile_loadd(src1, ...) // Load A
4. _tile_loadd(src2, ...) // Load B
5. _tile_dpbssd(dst, ...) // Compute C += A×B
6. _tile_stored(dst, ...) // Write result
7. _tile_release()        // Clean up
```