Skip to content

Instantly share code, notes, and snippets.

@maderix
Last active May 14, 2026 08:39
Show Gist options
  • Select an option

  • Save maderix/16d821902bfcc3debb30de4e5c735c66 to your computer and use it in GitHub Desktop.

Select an option

Save maderix/16d821902bfcc3debb30de4e5c735c66 to your computer and use it in GitHub Desktop.
ANE Prefill Pipeline — Qwen 3.5 9B on Apple M4 (221 tok/s, 5.05 TFLOPS, FP16)
// Qwen 3.5 9B Prefill Pipeline — ANE (Apple Neural Engine)
//
// 221 tok/s prefill, 5.05 TFLOPS ANE, FP16 weights, no quantization.
// Apple M4 (10-core: 4P+6E), 24 GB, macOS 15.
// Single self-contained file. No external dependencies beyond Apple frameworks.
//
// Architecture (32 layers, 24 DeltaNet + 8 Attention):
// DeltaNet: QKV proj (ANE) → recurrence (CPU/OpenMP) → out proj (ANE) → FFN (ANE)
// Attention: Q/K/V proj (ANE) → causal SDPA (CPU/AMX) → O proj (ANE) → FFN (ANE)
//
// ANE kernels (4 MIL graphs, compiled once, reused across layers):
// proj_large: conv1x1 [8192, 4096] — DeltaNet QKV, Attn Q
// proj_medium: conv1x1 [4096, 4096] — gate_z, out proj, O proj
// proj_small: conv1x1 [1024, 4096] — K, V proj
// fused_ffn: W1+SiLU+W3+gate+W2(K-tiled×3) — all 32 layers, single eval
//
// Matmul → conv1x1 reshape (ANE has no matmul, but conv is 7+ TFLOPS):
// Input A[S,K] → [1, K, 1, S] (channels=K, spatial=S)
// Weight B[N,K] → [N, K, 1, 1] (conv filter: OC=N, IC=K)
// Output [1, N, 1, S] → C[S,N]
//
// Results (S=256):
// Total: 1158 ms, 221 tok/s
// ANE projections: 243 ms (21%) — 5.05 TFLOPS
// ANE fused FFN: 457 ms (40%) — 5.48 TFLOPS
// CPU DeltaNet: 390 ms (34%) — OpenMP parallel 32 heads
// CPU attention: 18 ms (2%) — AMX/BLAS causal SDPA
// CPU RMSNorm+misc: 48 ms (4%)
//
// Build:
// xcrun clang -O2 -fobjc-arc -DACCELERATE_NEW_LAPACK \
// -Xclang -fopenmp -lomp \
// -framework Foundation -framework IOSurface -framework Accelerate -lm \
// -o /tmp/ane_prefill ane_prefill_pipeline.m
// /tmp/ane_prefill [seq_len]
#import <Foundation/Foundation.h>
#import <Accelerate/Accelerate.h>
#import <IOSurface/IOSurface.h>
#import <objc/runtime.h>
#import <objc/message.h>
#include <dlfcn.h>
#include <math.h>
#include <mach/mach_time.h>
#include <string.h>
#include <stdlib.h>
#include <omp.h>
// ═══════════════════════════════════════════════════════════════════
// ANE Bridge — inline private API wrapper
// ═══════════════════════════════════════════════════════════════════
// Private AppleNeuralEngine classes, looked up by name in ane_bridge_init().
static Class g_ANEDesc = nil;   // _ANEInMemoryModelDescriptor
static Class g_ANEInMem = nil;  // _ANEInMemoryModel
static Class g_ANEReq = nil;    // _ANERequest
static Class g_ANEIO = nil;     // _ANEIOSurfaceObject
// Set once all four classes above resolved successfully.
static bool g_ane_initialized = false;

// A compiled + loaded ANE program together with the IOSurfaces bound to its
// reusable request. Created by ane_compile(), destroyed by ane_free().
typedef struct ANEKernelHandle {
    id model;                 // _ANEInMemoryModel instance
    IOSurfaceRef *ioInputs;   // nInputs input surfaces
    IOSurfaceRef *ioOutputs;  // nOutputs output surfaces
    id request;               // _ANERequest binding the surfaces above
    NSString *tmpDir;         // temp directory holding model.mil / weights
    int nInputs;
    int nOutputs;
    size_t *inputBytes;       // byte size of each input surface
    size_t *outputBytes;      // byte size of each output surface
} ANEKernelHandle;
// Loads the private AppleNeuralEngine framework and resolves the four ANE
// classes used by this file. Idempotent: returns 0 immediately once it has
// succeeded. Returns -1 (after printing to stderr) on failure.
static int ane_bridge_init(void) {
    if (g_ane_initialized)
        return 0;

    void *framework = dlopen("/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine", RTLD_NOW);
    if (framework == NULL) {
        fprintf(stderr, "ane_bridge: Failed to load AppleNeuralEngine.framework\n");
        return -1;
    }

    g_ANEDesc = NSClassFromString(@"_ANEInMemoryModelDescriptor");
    g_ANEInMem = NSClassFromString(@"_ANEInMemoryModel");
    g_ANEReq = NSClassFromString(@"_ANERequest");
    g_ANEIO = NSClassFromString(@"_ANEIOSurfaceObject");

    const bool allResolved = g_ANEDesc && g_ANEInMem && g_ANEReq && g_ANEIO;
    if (!allResolved) {
        fprintf(stderr, "ane_bridge: Failed to resolve ANE private classes\n");
        return -1;
    }

    g_ane_initialized = true;
    return 0;
}
// Creates a flat 1-row IOSurface of `bytes` bytes (1 byte per element, pixel
// format 0) used as a raw ANE input/output buffer. The caller owns the
// returned surface (CFRelease); returns NULL if IOSurfaceCreate fails.
static IOSurfaceRef create_surface(size_t bytes) {
    NSNumber *byteCount = @(bytes);
    NSDictionary *properties = @{
        (id)kIOSurfaceWidth: byteCount,
        (id)kIOSurfaceHeight: @1,
        (id)kIOSurfaceBytesPerElement: @1,
        (id)kIOSurfaceBytesPerRow: byteCount,
        (id)kIOSurfaceAllocSize: byteCount,
        (id)kIOSurfacePixelFormat: @0,
    };
    return IOSurfaceCreate((__bridge CFDictionaryRef)properties);
}
// Releases everything a partially built handle owns. Used on ane_compile's
// failure paths after the model has been loaded. It performs the same steps
// as ane_free (unload, release surfaces, delete the temp dir, free memory),
// and it also tolerates surface arrays that were never allocated.
static void ane_compile_discard(ANEKernelHandle *k) {
    if (!k) return;
    NSError *e = nil;
    if (k->model)
        ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(
            k->model, @selector(unloadWithQoS:error:), 21, &e);
    if (k->ioInputs)
        for (int i = 0; i < k->nInputs; i++)
            if (k->ioInputs[i]) CFRelease(k->ioInputs[i]);
    if (k->ioOutputs)
        for (int i = 0; i < k->nOutputs; i++)
            if (k->ioOutputs[i]) CFRelease(k->ioOutputs[i]);
    if (k->tmpDir) [[NSFileManager defaultManager] removeItemAtPath:k->tmpDir error:nil];
    free(k->ioInputs); free(k->ioOutputs);
    free(k->inputBytes); free(k->outputBytes);
    // ARC does not release strong fields of calloc'd structs on free().
    k->model = nil; k->request = nil; k->tmpDir = nil;
    free(k);
}

// Wraps each of the `n` surfaces in an _ANEIOSurfaceObject and appends the
// matching index 0..n-1 to `indices`. Returns nil if any wrapper is nil,
// because adding nil to an NSMutableArray would raise.
static NSMutableArray *ane_wrap_surfaces(IOSurfaceRef *surfaces, int n, NSMutableArray *indices) {
    NSMutableArray *objs = [NSMutableArray arrayWithCapacity:(NSUInteger)n];
    for (int i = 0; i < n; i++) {
        id obj = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(
            g_ANEIO, @selector(objectWithIOSurface:), surfaces[i]);
        if (!obj) return nil;
        [objs addObject:obj];
        [indices addObject:@(i)];
    }
    return objs;
}

// Compiles MIL text (plus an optional weight blob) with the private ANE
// in-memory model API, loads it, and binds freshly created IOSurfaces of the
// given byte sizes (one per input/output) to a reusable _ANERequest.
//
// The weight blob is exposed to the compiler as
// "@model_path/weights/weight.bin". The MIL text and weights are also written
// under NSTemporaryDirectory()/<model hex id>/, which ane_free deletes.
//
// Returns NULL (after logging to stderr and releasing everything acquired so
// far) on any failure, including when ane_bridge_init has not succeeded.
static ANEKernelHandle *ane_compile(const char *mil_text, size_t mil_len,
                                    const uint8_t *weight_data, size_t weight_len,
                                    int n_inputs, const size_t *input_sizes,
                                    int n_outputs, const size_t *output_sizes)
{
    @autoreleasepool {
        if (!g_ane_initialized) return NULL;
        if (!mil_text || n_inputs < 0 || n_outputs < 0) return NULL;
        if ((n_inputs > 0 && !input_sizes) || (n_outputs > 0 && !output_sizes)) return NULL;

        NSData *milData = [NSData dataWithBytes:mil_text length:mil_len];
        // Wrap the weight blob once: the same NSData goes to the descriptor
        // and to disk below.
        NSData *weights = (weight_data && weight_len > 0)
            ? [NSData dataWithBytes:weight_data length:weight_len] : nil;
        NSDictionary *wdict = weights
            ? @{@"@model_path/weights/weight.bin": @{@"offset": @0, @"data": weights}}
            : @{};
        NSError *e = nil;

        id desc = ((id(*)(Class,SEL,id,id,id))objc_msgSend)(
            g_ANEDesc, @selector(modelWithMILText:weights:optionsPlist:),
            milData, wdict, nil);
        if (!desc) { fprintf(stderr, "ane: modelWithMILText failed\n"); return NULL; }
        id mdl = ((id(*)(Class,SEL,id))objc_msgSend)(
            g_ANEInMem, @selector(inMemoryModelWithDescriptor:), desc);
        if (!mdl) { fprintf(stderr, "ane: inMemoryModelWithDescriptor failed\n"); return NULL; }
        id hx = ((id(*)(id,SEL))objc_msgSend)(mdl, @selector(hexStringIdentifier));
        // stringByAppendingPathComponent: raises on nil.
        if (!hx) { fprintf(stderr, "ane: hexStringIdentifier failed\n"); return NULL; }

        NSString *td = [NSTemporaryDirectory() stringByAppendingPathComponent:hx];
        NSFileManager *fm = [NSFileManager defaultManager];
        [fm createDirectoryAtPath:[td stringByAppendingPathComponent:@"weights"]
      withIntermediateDirectories:YES attributes:nil error:nil];
        [milData writeToFile:[td stringByAppendingPathComponent:@"model.mil"] atomically:YES];
        if (weights)
            [weights writeToFile:[td stringByAppendingPathComponent:@"weights/weight.bin"] atomically:YES];

        if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(
                mdl, @selector(compileWithQoS:options:error:), 21, @{}, &e)) {
            fprintf(stderr, "ane: compile failed: %s\n", e ? [[e description] UTF8String] : "?");
            [fm removeItemAtPath:td error:nil];
            return NULL;
        }
        BOOL loaded = ((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(
            mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e);
        if (!loaded) {
            // One retry after a short pause.
            usleep(100000);
            e = nil;
            loaded = ((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(
                mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e);
        }
        if (!loaded) {
            fprintf(stderr, "ane: load failed: %s\n", e ? [[e description] UTF8String] : "?");
            [fm removeItemAtPath:td error:nil];
            return NULL;
        }

        // From here on the model is loaded; every failure must unload it.
        ANEKernelHandle *k = (ANEKernelHandle *)calloc(1, sizeof(ANEKernelHandle));
        if (!k) {
            NSError *ue = nil;
            ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(
                mdl, @selector(unloadWithQoS:error:), 21, &ue);
            [fm removeItemAtPath:td error:nil];
            fprintf(stderr, "ane: out of memory\n");
            return NULL;
        }
        k->model = mdl; k->tmpDir = td;
        k->nInputs = n_inputs; k->nOutputs = n_outputs;
        // calloc (not malloc) so unfilled surface slots read as NULL during cleanup.
        size_t in_count = n_inputs > 0 ? (size_t)n_inputs : 1;
        size_t out_count = n_outputs > 0 ? (size_t)n_outputs : 1;
        k->inputBytes = (size_t *)calloc(in_count, sizeof(size_t));
        k->outputBytes = (size_t *)calloc(out_count, sizeof(size_t));
        k->ioInputs = (IOSurfaceRef *)calloc(in_count, sizeof(IOSurfaceRef));
        k->ioOutputs = (IOSurfaceRef *)calloc(out_count, sizeof(IOSurfaceRef));
        if (!k->inputBytes || !k->outputBytes || !k->ioInputs || !k->ioOutputs) {
            fprintf(stderr, "ane: out of memory\n");
            ane_compile_discard(k);
            return NULL;
        }
        for (int i = 0; i < n_inputs; i++) {
            k->inputBytes[i] = input_sizes[i];
            k->ioInputs[i] = create_surface(input_sizes[i]);
            if (!k->ioInputs[i]) {
                fprintf(stderr, "ane: IOSurfaceCreate failed for input %d (%zu bytes)\n", i, input_sizes[i]);
                ane_compile_discard(k);
                return NULL;
            }
        }
        for (int i = 0; i < n_outputs; i++) {
            k->outputBytes[i] = output_sizes[i];
            k->ioOutputs[i] = create_surface(output_sizes[i]);
            if (!k->ioOutputs[i]) {
                fprintf(stderr, "ane: IOSurfaceCreate failed for output %d (%zu bytes)\n", i, output_sizes[i]);
                ane_compile_discard(k);
                return NULL;
            }
        }

        NSMutableArray *iIdx = [NSMutableArray arrayWithCapacity:in_count];
        NSMutableArray *oIdx = [NSMutableArray arrayWithCapacity:out_count];
        NSMutableArray *wIns = ane_wrap_surfaces(k->ioInputs, n_inputs, iIdx);
        NSMutableArray *wOuts = ane_wrap_surfaces(k->ioOutputs, n_outputs, oIdx);
        if (!wIns || !wOuts) {
            fprintf(stderr, "ane: objectWithIOSurface failed\n");
            ane_compile_discard(k);
            return NULL;
        }
        k->request = ((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(
            g_ANEReq,
            @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
            wIns, iIdx, wOuts, oIdx, nil, nil, @0);
        if (!k->request) {
            fprintf(stderr, "ane: request creation failed\n");
            ane_compile_discard(k);
            return NULL;
        }
        return k;
    }
}
// Runs one synchronous evaluation of the kernel's bound request. The inputs
// must already be in the input surfaces, and the results land in the output
// surfaces. Returns false (and logs the ANE error to stderr) on failure, so a
// failed eval is never silent.
static bool ane_eval(ANEKernelHandle *k) {
    @autoreleasepool {
        if (!k || !k->model || !k->request) return false;
        NSError *e = nil;
        BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
            k->model, @selector(evaluateWithQoS:options:request:error:), 21, @{}, k->request, &e);
        if (!ok)
            fprintf(stderr, "ane: evaluate failed: %s\n", e ? [[e description] UTF8String] : "?");
        return ok;
    }
}
// Copies `bytes` bytes from `data` into input surface `idx` under the surface
// lock. The request is rejected with an error, and nothing is written, when the
// index is invalid or `bytes` exceeds the surface size recorded by
// ane_compile. Previously such a request overflowed the IOSurface.
static void ane_write_input(ANEKernelHandle *k, int idx, const void *data, size_t bytes) {
    if (!k || idx < 0 || idx >= k->nInputs || !k->ioInputs[idx]) {
        fprintf(stderr, "ane_write_input: invalid input index %d\n", idx);
        return;
    }
    if (bytes > k->inputBytes[idx]) {
        fprintf(stderr, "ane_write_input: %zu bytes requested, surface holds %zu\n",
                bytes, k->inputBytes[idx]);
        return;
    }
    IOSurfaceRef surf = k->ioInputs[idx];
    if (IOSurfaceLock(surf, 0, NULL) != KERN_SUCCESS) {
        fprintf(stderr, "ane_write_input: IOSurfaceLock failed\n");
        return;
    }
    memcpy(IOSurfaceGetBaseAddress(surf), data, bytes);
    IOSurfaceUnlock(surf, 0, NULL);
}
// Copies `bytes` bytes out of output surface `idx` into `data` under a
// read-only lock. When the index is invalid or `bytes` exceeds the surface
// size recorded by ane_compile, an error is logged and `data` is left
// untouched. Previously this over-read the IOSurface.
static void ane_read_output(ANEKernelHandle *k, int idx, void *data, size_t bytes) {
    if (!k || idx < 0 || idx >= k->nOutputs || !k->ioOutputs[idx]) {
        fprintf(stderr, "ane_read_output: invalid output index %d\n", idx);
        return;
    }
    if (bytes > k->outputBytes[idx]) {
        fprintf(stderr, "ane_read_output: %zu bytes requested, surface holds %zu\n",
                bytes, k->outputBytes[idx]);
        return;
    }
    IOSurfaceRef surf = k->ioOutputs[idx];
    if (IOSurfaceLock(surf, kIOSurfaceLockReadOnly, NULL) != KERN_SUCCESS) {
        fprintf(stderr, "ane_read_output: IOSurfaceLock failed\n");
        return;
    }
    memcpy(data, IOSurfaceGetBaseAddress(surf), bytes);
    IOSurfaceUnlock(surf, kIOSurfaceLockReadOnly, NULL);
}
// Tears down a handle produced by ane_compile. It unloads the model, releases
// the input/output IOSurfaces, deletes the temporary MIL/weights directory
// and frees the handle itself. Passing NULL is a no-op.
static void ane_free(ANEKernelHandle *k) {
    @autoreleasepool {
        if (k == NULL)
            return;
        if (k->model != nil) {
            NSError *unloadErr = nil;
            ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(
                k->model, @selector(unloadWithQoS:error:), 21, &unloadErr);
        }
        for (int i = 0; i < k->nInputs; ++i) {
            IOSurfaceRef surf = k->ioInputs[i];
            if (surf != NULL) CFRelease(surf);
        }
        for (int i = 0; i < k->nOutputs; ++i) {
            IOSurfaceRef surf = k->ioOutputs[i];
            if (surf != NULL) CFRelease(surf);
        }
        if (k->tmpDir != nil)
            [[NSFileManager defaultManager] removeItemAtPath:k->tmpDir error:nil];
        free(k->ioInputs);
        free(k->ioOutputs);
        free(k->inputBytes);
        free(k->outputBytes);
        // ARC does not release strong fields of a malloc'd struct on free(),
        // so drop them explicitly first.
        k->model = nil;
        k->request = nil;
        k->tmpDir = nil;
        free(k);
    }
}
// ═══════════════════════════════════════════════════════════════════
// Timing
// ═══════════════════════════════════════════════════════════════════
// Mach timebase (numer/denom) for converting mach_absolute_time() ticks to
// nanoseconds; filled by mach_timebase_info() at the start of main.
static mach_timebase_info_data_t tb;

// Converts a mach_absolute_time() tick delta to milliseconds.
static double ticks_to_ms(uint64_t t) {
    const double nanoseconds = (double)t * tb.numer / tb.denom;
    return nanoseconds / 1e6;
}
// ═══════════════════════════════════════════════════════════════════
// MIL Generation — matmul as conv1x1
//
// ANE executes MIL (Model Intermediate Language) graphs compiled by
// Apple's private ANE compiler. Matmul maps to conv with 1×1 filters:
// Input A[S,K] → [1, K, 1, S] (batch=1, channels=K, H=1, W=S)
// Weight B[N,K] → [N, K, 1, 1] (OC=N, IC=K, kH=1, kW=1)
// Output → [1, N, 1, S] → C[S,N]
// ═══════════════════════════════════════════════════════════════════
// Common MIL prologue: the program(1.3) version line, the buildInfo
// dictionary, and the opening "{" of the program body. Both generators below
// start their text with it and emit the matching closing "}" themselves.
// NOTE: macro bodies use line continuations, so comments must stay outside.
#define MIL_HDR \
"program(1.3)\n" \
"[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \
"{\"coremltools-version\", \"9.0\"}})]\n{\n"
// Shared conv hyper-parameter constants, referenced by name from every conv
// op (pad_type = pt, strides = st, pad = pd, dilations = dl, groups = gr):
// "valid" padding, stride [1, 1], zero pads, dilation [1, 1], one group —
// i.e. a plain dense 1×1 convolution.
#define CONV_CONST \
" string pt = const()[name = string(\"pt\"), val = string(\"valid\")];\n" \
" tensor<int32, [2]> st = const()[name = string(\"st\"), " \
"val = tensor<int32, [2]>([1, 1])];\n" \
" tensor<int32, [4]> pd = const()[name = string(\"pd\"), " \
"val = tensor<int32, [4]>([0, 0, 0, 0])];\n" \
" tensor<int32, [2]> dl = const()[name = string(\"dl\"), " \
"val = tensor<int32, [2]>([1, 1])];\n" \
" int32 gr = const()[name = string(\"gr\"), val = int32(1)];\n"
// Builds a BLOBFILE-style weight container holding `n_chunks` fp16 tensors
// of elems_per_chunk[i] elements each. The layout written here is:
//   bytes [0, 64)  file header: byte 0 = 1, byte 4 = 2, rest zero
//   per chunk:     64-byte header (bytes EF BE AD DE at +0, byte +4 = 1,
//                  uint32 payload byte count at +8), then the fp16 payload
// Only the first 2000 values of each chunk get pseudo-random data in
// [-1, 0.998]; the rest stay zero (placeholder weights, not a checkpoint).
// Reseeds rand() with 42 on every call, so the contents are deterministic.
// Returns a calloc'd buffer of blob_total() bytes (caller frees), or NULL on
// allocation failure.
static uint8_t *make_blob(int n_chunks, const int *elems_per_chunk) {
    size_t total = 64;
    for (int i = 0; i < n_chunks; i++)
        total += 64 + (size_t)elems_per_chunk[i] * 2;
    uint8_t *blob = calloc(total, 1);
    if (!blob) return NULL;
    blob[0] = 1; blob[4] = 2;
    size_t off = 64;
    srand(42);
    for (int i = 0; i < n_chunks; i++) {
        uint8_t *chunk = blob + off;
        chunk[0] = 0xEF; chunk[1] = 0xBE; chunk[2] = 0xAD; chunk[3] = 0xDE;
        chunk[4] = 1;
        // memcpy instead of *(uint32_t*): after a chunk with an odd element
        // count, chunk + 8 is only 2-byte aligned and a direct store is UB.
        uint32_t payload_bytes = (uint32_t)((size_t)elems_per_chunk[i] * 2);
        memcpy(chunk + 8, &payload_bytes, sizeof payload_bytes);
        _Float16 *fp = (_Float16 *)(chunk + 64);
        for (int j = 0; j < elems_per_chunk[i] && j < 2000; j++)
            fp[j] = (_Float16)(((float)(rand() % 1000) - 500) * 0.002f);
        off += 64 + (size_t)elems_per_chunk[i] * 2;
    }
    return blob;
}
// Byte size of the container make_blob builds for the same chunk list:
// a 64-byte file header plus, per chunk, a 64-byte chunk header and
// 2 bytes (one fp16) per element.
static size_t blob_total(int n_chunks, int *elems_per_chunk) {
    size_t bytes = 64;
    for (int remaining = n_chunks; remaining-- > 0;)
        bytes += 64 + 2 * (size_t)elems_per_chunk[remaining];
    return bytes;
}
// ─── Single projection: out[S,OC] = in[S,IC] × W[OC,IC]^T ───
//
// MIL program:
// func main(tensor<fp16, [1, IC, 1, S]> x) {
// W = const<fp16, [OC, IC, 1, 1]>(BLOBFILE(...));
// y = conv(weight=W, x=x); // → [1, OC, 1, S]
// } -> (y);
//
// The weights come from make_blob (placeholder values). The kernel has one fp16
// input surface of IC·SP elements and one output surface of OC·SP elements.
// Returns NULL (and logs) on any failure, including MIL text that does not
// fit the buffer.
static ANEKernelHandle *compile_proj(int oc, int ic, int sp) {
    const size_t cap = 4096;
    ANEKernelHandle *k = NULL;
    uint8_t *blob = NULL;
    char *mil = malloc(cap);
    if (mil) {
        int n = snprintf(mil, cap,
            MIL_HDR
            " func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n"
            CONV_CONST
            " tensor<fp16, [%d, %d, 1, 1]> W = const()[name = string(\"W\"), "
            "val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE("
            "path = string(\"@model_path/weights/weight.bin\"), "
            "offset = uint64(64)))];\n"
            " tensor<fp16, [1, %d, 1, %d]> y = conv(dilations = dl, groups = gr, "
            "pad = pd, pad_type = pt, strides = st, weight = W, x = x)"
            "[name = string(\"c\")];\n"
            " } -> (y);\n}\n",
            ic, sp, oc, ic, oc, ic, oc, sp);
        // snprintf returns the length it wanted to write; on truncation that
        // exceeds the buffer and must not be passed on as the MIL length.
        if (n >= 0 && (size_t)n < cap) {
            int chunks[] = {oc * ic};
            blob = make_blob(1, chunks);
            if (blob) {
                size_t isz[] = {(size_t)ic * sp * 2};
                size_t osz[] = {(size_t)oc * sp * 2};
                k = ane_compile(mil, (size_t)n, blob, blob_total(1, chunks), 1, isz, 1, osz);
            }
        } else {
            fprintf(stderr, "compile_proj: MIL text does not fit in %zu bytes\n", cap);
        }
    }
    free(mil);
    free(blob);
    if (!k) fprintf(stderr, "FAILED to compile proj [%d, %d] sp=%d\n", oc, ic, sp);
    return k;
}
// Advances an append cursor after snprintf produced `n` at *p with *rem bytes
// of room. Returns false on an encoding error or truncation, leaving the
// cursor untouched, so the caller stops before `rem` can go negative (and
// then be read as a huge size_t by the next snprintf).
static bool mil_advance(int n, char **p, int *rem) {
    if (n < 0 || n >= *rem) return false;
    *p += n;
    *rem -= n;
    return true;
}

// ─── Fused FFN: W1+SiLU+W3+gate+W2(K-tiled×3) ───
//
// Single MIL graph, one ANE eval per layer:
// h1 = conv(W1, x) → [1, FFN, 1, S] gate projection
// sig = sigmoid(h1)
// silu = mul(h1, sig) SiLU activation
// h3 = conv(W3, x) → [1, FFN, 1, S] up projection
// gate = mul(silu, h3) gated activation
// d0 = conv(W2_0, gate[0:FFN/3]) ┐
// d1 = conv(W2_1, gate[FFN/3:2FFN/3]) ├ K-tiled W2 (FFN → 3 × FFN/3)
// d2 = conv(W2_2, gate[2FFN/3:FFN]) ┘
// y = d0 + d1 + d2 → [1, DIM, 1, S] down projection
//
// K-tiling splits W2[DIM, FFN] into 3 × W2[DIM, FFN/3] so each fits ANE SRAM;
// hence ffn must be divisible by 3.
// Weight blob layout (one BLOBFILE, placeholder values): W1, W3, W2_0, W2_1,
// W2_2, each preceded by a 64-byte chunk header.
// Returns NULL (and logs) on any failure.
static ANEKernelHandle *compile_fused_ffn(int dim, int ffn, int sp) {
    if (ffn % 3 != 0) {
        fprintf(stderr, "compile_fused_ffn: ffn=%d is not divisible by 3 (K-tiling)\n", ffn);
        fprintf(stderr, "FAILED to compile fused_ffn sp=%d\n", sp);
        return NULL;
    }
    const int ic_per = ffn / 3;
    const int w1_elems = ffn * dim;
    const int w3_elems = ffn * dim;
    const int w2t_elems = dim * ic_per;
    // Byte offsets of each chunk header inside the blob (BLOBFILE offset).
    const size_t w1_off = 64;
    const size_t w3_off = w1_off + 64 + (size_t)w1_elems * 2;
    const size_t w2_off = w3_off + 64 + (size_t)w3_elems * 2;
    const size_t w2t_cs = 64 + (size_t)w2t_elems * 2;

    const size_t cap = 16384;
    char *mil = malloc(cap);
    if (!mil) {
        fprintf(stderr, "FAILED to compile fused_ffn sp=%d\n", sp);
        return NULL;
    }
    char *p = mil;
    int rem = (int)cap;

    bool ok = mil_advance(snprintf(p, rem,
        MIL_HDR
        " func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n"
        CONV_CONST,
        dim, sp), &p, &rem);
    ok = ok && mil_advance(snprintf(p, rem,
        " tensor<fp16, [%d, %d, 1, 1]> W1 = const()[name = string(\"W1\"), "
        "val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE("
        "path = string(\"@model_path/weights/weight.bin\"), "
        "offset = uint64(%lu)))];\n",
        ffn, dim, ffn, dim, (unsigned long)w1_off), &p, &rem);
    ok = ok && mil_advance(snprintf(p, rem,
        " tensor<fp16, [%d, %d, 1, 1]> W3 = const()[name = string(\"W3\"), "
        "val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE("
        "path = string(\"@model_path/weights/weight.bin\"), "
        "offset = uint64(%lu)))];\n",
        ffn, dim, ffn, dim, (unsigned long)w3_off), &p, &rem);
    ok = ok && mil_advance(snprintf(p, rem,
        " tensor<fp16, [1, %d, 1, %d]> h1 = conv(dilations = dl, groups = gr, "
        "pad = pd, pad_type = pt, strides = st, weight = W1, x = x)"
        "[name = string(\"h1\")];\n"
        " tensor<fp16, [1, %d, 1, %d]> sig1 = sigmoid(x = h1)"
        "[name = string(\"sig1\")];\n"
        " tensor<fp16, [1, %d, 1, %d]> silu = mul(x = h1, y = sig1)"
        "[name = string(\"silu\")];\n"
        " tensor<fp16, [1, %d, 1, %d]> h3 = conv(dilations = dl, groups = gr, "
        "pad = pd, pad_type = pt, strides = st, weight = W3, x = x)"
        "[name = string(\"h3\")];\n"
        " tensor<fp16, [1, %d, 1, %d]> gate = mul(x = silu, y = h3)"
        "[name = string(\"gate\")];\n",
        ffn, sp, ffn, sp, ffn, sp, ffn, sp, ffn, sp), &p, &rem);
    // Three K-tiles: slice gate along channels, conv with its W2 tile.
    for (int t = 0; ok && t < 3; t++) {
        int begin_c = t * ic_per;
        size_t w2t_off = w2_off + (size_t)t * w2t_cs;
        ok = mil_advance(snprintf(p, rem,
            " tensor<int32, [4]> gb%d = const()[name = string(\"gb%d\"), "
            "val = tensor<int32, [4]>([0, %d, 0, 0])];\n"
            " tensor<int32, [4]> gsz%d = const()[name = string(\"gsz%d\"), "
            "val = tensor<int32, [4]>([1, %d, 1, %d])];\n"
            " tensor<fp16, [1, %d, 1, %d]> g%d = slice_by_size(begin = gb%d, "
            "size = gsz%d, x = gate)[name = string(\"g%d\")];\n"
            " tensor<fp16, [%d, %d, 1, 1]> W2_%d = const()[name = string(\"W2_%d\"), "
            "val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE("
            "path = string(\"@model_path/weights/weight.bin\"), "
            "offset = uint64(%lu)))];\n"
            " tensor<fp16, [1, %d, 1, %d]> d%d = conv(dilations = dl, groups = gr, "
            "pad = pd, pad_type = pt, strides = st, weight = W2_%d, x = g%d)"
            "[name = string(\"d%d\")];\n",
            t, t, begin_c,
            t, t, ic_per, sp,
            ic_per, sp, t, t, t, t,
            dim, ic_per, t, t, dim, ic_per, (unsigned long)w2t_off,
            dim, sp, t, t, t, t), &p, &rem);
    }
    ok = ok && mil_advance(snprintf(p, rem,
        " tensor<fp16, [1, %d, 1, %d]> a01 = add(x = d0, y = d1)"
        "[name = string(\"a01\")];\n"
        " tensor<fp16, [1, %d, 1, %d]> y = add(x = a01, y = d2)"
        "[name = string(\"sum\")];\n"
        " } -> (y);\n}\n",
        dim, sp, dim, sp), &p, &rem);

    ANEKernelHandle *k = NULL;
    if (ok) {
        size_t mil_len = (size_t)(p - mil);
        int chunks[] = {w1_elems, w3_elems, w2t_elems, w2t_elems, w2t_elems};
        size_t blen = blob_total(5, chunks);
        uint8_t *blob = calloc(blen, 1);
        if (blob) {
            blob[0] = 1; blob[4] = 2;
            size_t off = 64;
            // Same chunk format as make_blob, but continues the current rand()
            // sequence instead of reseeding it.
            for (int i = 0; i < 5; i++) {
                uint8_t *chunk = blob + off;
                chunk[0] = 0xEF; chunk[1] = 0xBE; chunk[2] = 0xAD; chunk[3] = 0xDE;
                chunk[4] = 1;
                uint32_t payload_bytes = (uint32_t)((size_t)chunks[i] * 2);
                memcpy(chunk + 8, &payload_bytes, sizeof payload_bytes);
                _Float16 *fp = (_Float16 *)(chunk + 64);
                for (int j = 0; j < chunks[i] && j < 2000; j++)
                    fp[j] = (_Float16)(((float)(rand() % 1000) - 500) * 0.002f);
                off += 64 + (size_t)chunks[i] * 2;
            }
            size_t isz[] = {(size_t)dim * sp * 2};
            size_t osz[] = {(size_t)dim * sp * 2};
            k = ane_compile(mil, mil_len, blob, blen, 1, isz, 1, osz);
            free(blob);
        }
    } else {
        fprintf(stderr, "compile_fused_ffn: MIL text does not fit in %zu bytes\n", cap);
    }
    free(mil);
    if (!k) fprintf(stderr, "FAILED to compile fused_ffn sp=%d\n", sp);
    return k;
}
// ═══════════════════════════════════════════════════════════════════
// CPU Ops (RMSNorm, DeltaNet recurrence, attention, residual)
// ═══════════════════════════════════════════════════════════════════
// Row-wise RMSNorm over S rows of `dim` floats:
//   out[r][i] = inp[r][i] · w[i] / sqrt(mean(inp[r]²) + 1e-6)
// `inp` and `out` are contiguous S × dim row-major buffers.
static void cpu_rmsnorm(float *out, const float *inp, const float *w, int dim, int S) {
    for (int row = 0; row < S; ++row) {
        const float *src = &inp[row * dim];
        float *dst = &out[row * dim];

        float sum_sq = 0;
        for (int i = 0; i < dim; ++i)
            sum_sq += src[i] * src[i];
        const float inv_rms = 1.0f / sqrtf(sum_sq / dim + 1e-6f);

        for (int i = 0; i < dim; ++i)
            dst[i] = src[i] * inv_rms * w[i];
    }
}
// In-place residual connection: hidden += proj over S × dim floats (vDSP).
static void cpu_residual_add(float *hidden, const float *proj, int dim, int S) {
    const vDSP_Length count = (vDSP_Length)(dim * S);
    vDSP_vadd(proj, 1, hidden, 1, hidden, 1, count);
}
// DeltaNet recurrence — per-head state update via BLAS (AMX)
//
// Tokens are processed sequentially (the state carries from token to token).
// Within a token, value heads h = 0..H-1 run in parallel (OpenMP):
//   S_h ← 0.99 · S_h                          decay
//   k̂  = k / (‖k‖ + 1e-8)
//   e   = 0.5 · (v_h − S_h k̂)                 delta-rule error (β = 0.5)
//   S_h ← S_h + e k̂ᵀ                          rank-1 update; S_h is D×D row-major
//   q̂  = q / (‖q‖ + 1e-8) / √D
//   o_h = S_h q̂
// The update is oriented so that S_h k̂ moves toward v_h, matching the S_h·k̂
// read. The previous code applied k̂ eᵀ (the transpose), which is not a
// delta rule.
//
// Row layout of qkv_buf, with a row stride of 2·H·D floats (= 8192 for the
// H=32, D=128 used by main, i.e. the QKV projection width):
//   q[H/2 · D] | k[H/2 · D] | v[H · D]
// Value heads 2g and 2g+1 share query/key head g.
// NOTE(review): this grouped layout is an assumption, chosen so that q|k|v
// fits the 8192-wide projection row (Qwen3-Next-style gated DeltaNet uses
// 16 QK heads / 32 V heads); confirm it against the real model config. The
// previous layout (three H·D slices in an 8192-float row) read v past the end
// of qkv_buf on the last token. H must be even.
//
// qkv_buf is only read: normalized q/k go to per-iteration stack copies,
// because two value heads share each key head.
// o_buf:   S rows of H·D floats (written).
// state:   H·D·D floats, updated in place.
// tmp_buf: H·D floats of scratch.
static void deltanet_recurrence(
    float *qkv_buf, float *o_buf, float *state, float *tmp_buf,
    int S, int H, int D)
{
    if (S <= 0 || H <= 0 || D <= 0) return;
    if (H % 2 != 0) {
        fprintf(stderr, "deltanet_recurrence: H=%d must be even (H/2 shared QK heads)\n", H);
        return;
    }
    const int HD = H * D;
    const int DD = D * D;
    const int qk_width = (H / 2) * D;          // floats in the q slice and in the k slice
    const size_t row_stride = (size_t)2 * HD;  // qk_width + qk_width + HD
    for (int s = 0; s < S; s++) {
        const float *q = qkv_buf + (size_t)s * row_stride;
        const float *k = q + qk_width;
        const float *v = k + qk_width;
        float *o = o_buf + (size_t)s * HD;
#pragma omp parallel for schedule(static)
        for (int h = 0; h < H; h++) {
            float *st = state + (size_t)h * DD;
            const float *qh = q + (h / 2) * D;
            const float *kh = k + (h / 2) * D;
            const float *vh = v + h * D;
            float *th = tmp_buf + h * D;
            float kn[D], qn[D];   // private normalized copies (shared heads)

            cblas_sscal(DD, 0.99f, st, 1);

            const float k_inv = 1.0f / (cblas_snrm2(D, kh, 1) + 1e-8f);
            for (int i = 0; i < D; i++) kn[i] = kh[i] * k_inv;
            cblas_sgemv(CblasRowMajor, CblasNoTrans, D, D,
                        1.0f, st, D, kn, 1, 0.0f, th, 1);
            for (int i = 0; i < D; i++) th[i] = 0.5f * (vh[i] - th[i]);
            // S[i][j] += e[i] * k̂[j]
            cblas_sger(CblasRowMajor, D, D, 1.0f, th, 1, kn, 1, st, D);

            const float q_inv = 1.0f / (cblas_snrm2(D, qh, 1) + 1e-8f) / sqrtf((float)D);
            for (int i = 0; i < D; i++) qn[i] = qh[i] * q_inv;
            cblas_sgemv(CblasRowMajor, CblasNoTrans, D, D,
                        1.0f, st, D, qn, 1, 0.0f, o + h * D, 1);
        }
    }
}
// In-place softmax over row[0..last] (inclusive, max-subtracted for
// stability); entries last+1..len-1 are set to 0 so that masked (future) keys
// contribute nothing to the following P·V product.
static void causal_softmax_row(float *row, int last, int len) {
    float peak = -1e30f;
    for (int j = 0; j <= last; j++)
        peak = row[j] > peak ? row[j] : peak;

    float total = 0;
    for (int j = 0; j <= last; j++) {
        row[j] = expf(row[j] - peak);
        total += row[j];
    }

    const float inv_total = 1.0f / total;
    for (int j = 0; j <= last; j++)
        row[j] *= inv_total;
    for (int j = last + 1; j < len; j++)
        row[j] = 0;
}

// Causal attention — GQA, Accelerate BLAS (AMX)
//
// For each query head hq (using KV head hq / (Hq / Hkv)):
//   P = softmax_causal(Q_hq · K_hkvᵀ / √D)    [chunk_len × kv_len] in scores_buf
//   O_hq = P · V_hkv                           [chunk_len × D]
// Query row qi sits at absolute position start_pos + qi and may attend to keys
// 0..start_pos+qi. Strides: q_buf / o_buf rows are Hq·D floats, k_cache /
// v_cache rows are Hkv·D floats. scores_buf needs chunk_len·kv_len floats.
static void causal_attention(
    const float *q_buf, const float *k_cache, const float *v_cache,
    float *o_buf, float *scores_buf,
    int chunk_len, int kv_len, int start_pos,
    int Hq, int Hkv, int D)
{
    const int heads_per_kv = Hq / Hkv;
    const float inv_sqrt_d = 1.0f / sqrtf((float)D);
    const int q_stride = Hq * D;
    const int kv_stride = Hkv * D;

    for (int hq = 0; hq < Hq; hq++) {
        const int hkv = hq / heads_per_kv;

        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    chunk_len, kv_len, D, inv_sqrt_d,
                    q_buf + hq * D, q_stride,
                    k_cache + hkv * D, kv_stride,
                    0.0f, scores_buf, kv_len);

        for (int qi = 0; qi < chunk_len; qi++)
            causal_softmax_row(scores_buf + qi * kv_len, start_pos + qi, kv_len);

        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                    chunk_len, D, kv_len, 1.0f,
                    scores_buf, kv_len,
                    v_cache + hkv * D, kv_stride,
                    0.0f, o_buf + hq * D, q_stride);
    }
}
// ═══════════════════════════════════════════════════════════════════
// ANE dispatch: write fp32 → fp16, eval on ANE, read fp16 → fp32
// ═══════════════════════════════════════════════════════════════════
// Runs one ANE kernel on fp32 host data. It converts `in_elems` floats to fp16
// into input surface 0, evaluates, then converts `out_elems` fp16 values from
// output surface 0 back into out_f32.
// If the evaluation fails, an error is logged and out_f32 is left unmodified,
// so stale surface contents are never returned as a fresh result.
static void ane_dispatch(ANEKernelHandle *k,
                         const float *in_f32, int in_elems,
                         float *out_f32, int out_elems)
{
    if (in_elems < 0 || out_elems < 0) return;
    const size_t in_bytes = (size_t)in_elems * sizeof(_Float16);
    const size_t out_bytes = (size_t)out_elems * sizeof(_Float16);

    _Float16 *in16 = (_Float16 *)malloc(in_bytes ? in_bytes : 1);
    if (!in16) {
        fprintf(stderr, "ane_dispatch: out of memory\n");
        return;
    }
    for (int i = 0; i < in_elems; i++) in16[i] = (_Float16)in_f32[i];
    ane_write_input(k, 0, in16, in_bytes);
    free(in16);

    if (!ane_eval(k)) {
        fprintf(stderr, "ane_dispatch: ANE evaluation failed; output not updated\n");
        return;
    }

    _Float16 *out16 = (_Float16 *)malloc(out_bytes ? out_bytes : 1);
    if (!out16) {
        fprintf(stderr, "ane_dispatch: out of memory\n");
        return;
    }
    ane_read_output(k, 0, out16, out_bytes);
    for (int i = 0; i < out_elems; i++) out_f32[i] = (float)out16[i];
    free(out16);
}
// ═══════════════════════════════════════════════════════════════════
// Main Pipeline
// ═══════════════════════════════════════════════════════════════════
// Benchmark driver. It compiles the four ANE kernels, allocates activations
// plus per-layer recurrent state, runs the 32-layer prefill over S tokens in
// CHUNK-sized pieces, and prints a timing / throughput breakdown.
//
// Usage: ane_prefill [seq_len=256] [chunk_len=seq_len]
// Weights (make_blob) and input activations are random placeholders, so only
// the timing is meaningful.
int main(int argc, char **argv) {
    @autoreleasepool {
        setbuf(stdout, NULL);
        mach_timebase_info(&tb);
        srand(42);

        int S = 256;
        if (argc > 1) S = atoi(argv[1]);
        int CHUNK = S;
        if (argc > 2) CHUNK = atoi(argv[2]);
        // atoi() yields 0 for non-numeric text. CHUNK == 0 would divide by
        // zero in the modulo below, and S <= 0 leaves nothing to run.
        if (S <= 0 || CHUNK <= 0) {
            printf(" S=%d and CHUNK=%d must both be positive\n", S, CHUNK);
            return 1;
        }
        if (S % CHUNK != 0) {
            printf(" S=%d must be divisible by CHUNK=%d\n", S, CHUNK);
            return 1;
        }
        int n_chunks = S / CHUNK;

        int DIM = 4096, FFN = 12288;
        int QKV_DIM = 8192;   // width of the DeltaNet QKV / attention Q projection
        int N_LAYERS = 32;
        int N_DELTANET = 24, N_ATTN = 8;
        int DN_H = 32, DN_D = 128;
        int ATTN_HQ = 16, ATTN_HKV = 4, ATTN_D = 256;
        int kv_dim = ATTN_HKV * ATTN_D;   // width of the K and V projections
        int attn_layers[] = {3,7,11,15,19,23,27,31};
        int is_attn[32] = {0};
        for (int i = 0; i < N_ATTN; i++) is_attn[attn_layers[i]] = 1;
        // Index of each layer among the layers of its own type. It selects
        // that layer's private DeltaNet state or KV cache below.
        int layer_slot[32];
        for (int l = 0, n_dn = 0, n_at = 0; l < N_LAYERS; l++)
            layer_slot[l] = is_attn[l] ? n_at++ : n_dn++;

        printf("═══════════════════════════════════════════════════════════════\n");
        printf(" ANE Prefill Pipeline — Qwen 3.5 9B\n");
        printf(" S=%d, CHUNK=%d (%d chunks), %d layers (%d DeltaNet + %d Attention)\n",
               S, CHUNK, n_chunks, N_LAYERS, N_DELTANET, N_ATTN);
        printf(" dim=%d, ffn=%d\n", DIM, FFN);
        printf("═══════════════════════════════════════════════════════════════\n\n");

        // ─── Compile ANE kernels ───
        printf(" Compiling ANE kernels (sp=%d)...\n", CHUNK);
        if (ane_bridge_init() != 0) {
            printf(" FATAL: ANE bridge initialization failed\n");
            return 1;
        }
        uint64_t t_compile_start = mach_absolute_time();
        ANEKernelHandle *k_proj_large = compile_proj(QKV_DIM, DIM, CHUNK);
        ANEKernelHandle *k_proj_medium = compile_proj(DIM, DIM, CHUNK);
        ANEKernelHandle *k_proj_small = compile_proj(kv_dim, DIM, CHUNK);
        ANEKernelHandle *k_ffn = compile_fused_ffn(DIM, FFN, CHUNK);
        printf(" Compiled 4 kernels in %.0f ms\n", ticks_to_ms(mach_absolute_time() - t_compile_start));
        if (!k_proj_large || !k_proj_medium || !k_proj_small || !k_ffn) {
            printf(" FATAL: kernel compilation failed\n");
            ane_free(k_proj_large); ane_free(k_proj_medium);
            ane_free(k_proj_small); ane_free(k_ffn);
            return 1;
        }

        // ─── Buffers ───
        size_t buf_dim_C = (size_t)DIM * CHUNK;
        size_t buf_qkv_C = (size_t)QKV_DIM * CHUNK;
        size_t buf_kv_C = (size_t)kv_dim * CHUNK;
        size_t buf_dim_S = (size_t)DIM * S;
        // Recurrent state has to persist per layer across chunks, so each
        // DeltaNet layer owns an H×D×D state and each attention layer its own
        // K/V cache (previously one of each was shared by all layers).
        size_t dn_state_elems = (size_t)DN_H * DN_D * DN_D;
        size_t kv_cache_elems = (size_t)S * kv_dim;
        float *hidden = calloc(buf_dim_S, sizeof(float));
        float *normed = calloc(buf_dim_C, sizeof(float));
        float *proj_out = calloc(buf_dim_C, sizeof(float));
        float *qkv_buf = calloc(buf_qkv_C, sizeof(float));
        float *gz_buf = calloc(buf_dim_C, sizeof(float));
        float *q_buf = calloc(buf_qkv_C, sizeof(float));
        float *k_buf = calloc(buf_kv_C, sizeof(float));
        float *v_buf = calloc(buf_kv_C, sizeof(float));
        float *o_buf = calloc(buf_dim_C, sizeof(float));
        float *ffn_out = calloc(buf_dim_C, sizeof(float));
        float *dn_state = calloc(dn_state_elems * N_DELTANET, sizeof(float));
        float *dn_tmp = calloc((size_t)DN_H * DN_D, sizeof(float));
        float *k_cache = calloc(kv_cache_elems * N_ATTN, sizeof(float));
        float *v_cache = calloc(kv_cache_elems * N_ATTN, sizeof(float));
        float *attn_scores = calloc((size_t)CHUNK * S, sizeof(float));
        float *norm_w = calloc(DIM, sizeof(float));
        if (!hidden || !normed || !proj_out || !qkv_buf || !gz_buf || !q_buf ||
            !k_buf || !v_buf || !o_buf || !ffn_out || !dn_state || !dn_tmp ||
            !k_cache || !v_cache || !attn_scores || !norm_w) {
            // Process exits immediately; the OS reclaims the partial allocations.
            printf(" FATAL: buffer allocation failed\n");
            return 1;
        }
        for (int i = 0; i < DIM; i++) norm_w[i] = 1.0f;
        for (size_t i = 0; i < buf_dim_S; i++)
            hidden[i] = ((float)(rand() % 1000) - 500) * 0.001f;
        printf(" Buffers allocated\n\n");

        // ─── Warmup ───
        printf(" Warming up ANE (3 evals per kernel)...\n");
        {
            // Every kernel takes a DIM×CHUNK fp16 input.
            _Float16 *tmp = calloc(buf_dim_C, sizeof(_Float16));
            if (tmp) {
                ane_write_input(k_proj_large, 0, tmp, buf_dim_C * 2);
                ane_write_input(k_proj_medium, 0, tmp, buf_dim_C * 2);
                ane_write_input(k_proj_small, 0, tmp, buf_dim_C * 2);
                ane_write_input(k_ffn, 0, tmp, buf_dim_C * 2);
                free(tmp);
            }
            for (int i = 0; i < 3; i++) {
                ane_eval(k_proj_large);
                ane_eval(k_proj_medium);
                ane_eval(k_proj_small);
                ane_eval(k_ffn);
            }
        }
        printf(" Warm-up done\n\n");

        // ─── Pipeline ───
        printf(" Running %d-layer pipeline (S=%d, CHUNK=%d, %d chunks)...\n\n",
               N_LAYERS, S, CHUNK, n_chunks);
        double total_ane_proj_ms = 0, total_ane_ffn_ms = 0;
        double total_cpu_norm_ms = 0, total_cpu_dn_rec_ms = 0;
        double total_cpu_attn_ms = 0, total_cpu_misc_ms = 0;
        int ane_dispatches = 0;
        uint64_t t_pipeline_start = mach_absolute_time();
        for (int c = 0; c < n_chunks; c++) {
            int pos = c * CHUNK;
            float *h_chunk = hidden + (size_t)pos * DIM;
            for (int l = 0; l < N_LAYERS; l++) {
                uint64_t t0;
                t0 = mach_absolute_time();
                cpu_rmsnorm(normed, h_chunk, norm_w, DIM, CHUNK);
                total_cpu_norm_ms += ticks_to_ms(mach_absolute_time() - t0);
                if (!is_attn[l]) {
                    // ══════════ DeltaNet Layer ══════════
                    float *state = dn_state + (size_t)layer_slot[l] * dn_state_elems;
                    t0 = mach_absolute_time();
                    ane_dispatch(k_proj_large, normed, DIM * CHUNK, qkv_buf, QKV_DIM * CHUNK);
                    total_ane_proj_ms += ticks_to_ms(mach_absolute_time() - t0);
                    ane_dispatches++;
                    // gate_z projection: computed for timing, not applied below.
                    t0 = mach_absolute_time();
                    ane_dispatch(k_proj_medium, normed, DIM * CHUNK, gz_buf, DIM * CHUNK);
                    total_ane_proj_ms += ticks_to_ms(mach_absolute_time() - t0);
                    ane_dispatches++;
                    t0 = mach_absolute_time();
                    deltanet_recurrence(qkv_buf, o_buf, state, dn_tmp, CHUNK, DN_H, DN_D);
                    total_cpu_dn_rec_ms += ticks_to_ms(mach_absolute_time() - t0);
                    t0 = mach_absolute_time();
                    ane_dispatch(k_proj_medium, o_buf, DIM * CHUNK, proj_out, DIM * CHUNK);
                    total_ane_proj_ms += ticks_to_ms(mach_absolute_time() - t0);
                    ane_dispatches++;
                } else {
                    // ══════════ Attention Layer ══════════
                    float *kc = k_cache + (size_t)layer_slot[l] * kv_cache_elems;
                    float *vc = v_cache + (size_t)layer_slot[l] * kv_cache_elems;
                    // NOTE(review): q_buf rows are QKV_DIM wide, but causal_attention
                    // reads them with a stride of ATTN_HQ*ATTN_D; confirm the intended
                    // Q (+gate) layout.
                    t0 = mach_absolute_time();
                    ane_dispatch(k_proj_large, normed, DIM * CHUNK, q_buf, QKV_DIM * CHUNK);
                    total_ane_proj_ms += ticks_to_ms(mach_absolute_time() - t0);
                    ane_dispatches++;
                    t0 = mach_absolute_time();
                    ane_dispatch(k_proj_small, normed, DIM * CHUNK, k_buf, kv_dim * CHUNK);
                    total_ane_proj_ms += ticks_to_ms(mach_absolute_time() - t0);
                    ane_dispatches++;
                    t0 = mach_absolute_time();
                    ane_dispatch(k_proj_small, normed, DIM * CHUNK, v_buf, kv_dim * CHUNK);
                    total_ane_proj_ms += ticks_to_ms(mach_absolute_time() - t0);
                    ane_dispatches++;
                    memcpy(kc + (size_t)pos * kv_dim, k_buf, buf_kv_C * sizeof(float));
                    memcpy(vc + (size_t)pos * kv_dim, v_buf, buf_kv_C * sizeof(float));
                    t0 = mach_absolute_time();
                    causal_attention(q_buf, kc, vc, o_buf, attn_scores,
                                     CHUNK, pos + CHUNK, pos, ATTN_HQ, ATTN_HKV, ATTN_D);
                    total_cpu_attn_ms += ticks_to_ms(mach_absolute_time() - t0);
                    t0 = mach_absolute_time();
                    ane_dispatch(k_proj_medium, o_buf, DIM * CHUNK, proj_out, DIM * CHUNK);
                    total_ane_proj_ms += ticks_to_ms(mach_absolute_time() - t0);
                    ane_dispatches++;
                }
                t0 = mach_absolute_time();
                cpu_residual_add(h_chunk, proj_out, DIM, CHUNK);
                total_cpu_misc_ms += ticks_to_ms(mach_absolute_time() - t0);
                t0 = mach_absolute_time();
                cpu_rmsnorm(normed, h_chunk, norm_w, DIM, CHUNK);
                total_cpu_norm_ms += ticks_to_ms(mach_absolute_time() - t0);
                t0 = mach_absolute_time();
                ane_dispatch(k_ffn, normed, DIM * CHUNK, ffn_out, DIM * CHUNK);
                total_ane_ffn_ms += ticks_to_ms(mach_absolute_time() - t0);
                ane_dispatches++;
                t0 = mach_absolute_time();
                cpu_residual_add(h_chunk, ffn_out, DIM, CHUNK);
                total_cpu_misc_ms += ticks_to_ms(mach_absolute_time() - t0);
            }
            if ((c+1) % 4 == 0 || c == n_chunks - 1) {
                double elapsed = ticks_to_ms(mach_absolute_time() - t_pipeline_start);
                printf(" Chunk %2d/%d (pos %d-%d) done (%.0f ms elapsed)\n",
                       c+1, n_chunks, pos, pos+CHUNK-1, elapsed);
            }
        }
        double total_ms = ticks_to_ms(mach_absolute_time() - t_pipeline_start);

        // ─── Results ───
        printf("\n═══════════════════════════════════════════════════════════════\n");
        printf(" RESULTS — Qwen 3.5 9B Prefill (ANE + CPU), S=%d, CHUNK=%d\n", S, CHUNK);
        printf("═══════════════════════════════════════════════════════════════\n\n");
        printf(" Total wall time: %.1f ms\n", total_ms);
        printf(" Tokens: %d\n", S);
        printf(" Throughput: %.1f tok/s\n\n", S / (total_ms / 1000.0));
        printf(" ┌──────────────────────────────────────────────────────┐\n");
        printf(" │ Component │ Time (ms) │ %% of total │\n");
        printf(" ├──────────────────────────────────────────────────────┤\n");
        printf(" │ ANE projections │ %9.1f │ %6.1f%% │\n",
               total_ane_proj_ms, total_ane_proj_ms / total_ms * 100);
        printf(" │ ANE fused FFN │ %9.1f │ %6.1f%% │\n",
               total_ane_ffn_ms, total_ane_ffn_ms / total_ms * 100);
        printf(" │ CPU RMSNorm │ %9.1f │ %6.1f%% │\n",
               total_cpu_norm_ms, total_cpu_norm_ms / total_ms * 100);
        printf(" │ CPU DeltaNet recur. │ %9.1f │ %6.1f%% │\n",
               total_cpu_dn_rec_ms, total_cpu_dn_rec_ms / total_ms * 100);
        printf(" │ CPU causal attention │ %9.1f │ %6.1f%% │\n",
               total_cpu_attn_ms, total_cpu_attn_ms / total_ms * 100);
        printf(" │ CPU residual/misc │ %9.1f │ %6.1f%% │\n",
               total_cpu_misc_ms, total_cpu_misc_ms / total_ms * 100);
        printf(" └──────────────────────────────────────────────────────┘\n\n");
        printf(" ANE dispatches: %d (%.2f ms avg)\n",
               ane_dispatches, (total_ane_proj_ms + total_ane_ffn_ms) / ane_dispatches);
        // FLOPs derived from the configuration above (2 FLOPs per MAC):
        // DeltaNet layers: QKV + gate_z + out projections.
        // Attention layers: Q + K + V + O projections.
        // FFN (every layer): W1 + W3 + W2.
        double proj_flops = 0;
        proj_flops += (double)N_DELTANET * (2.0*QKV_DIM*DIM + 2*2.0*DIM*DIM) * S;
        proj_flops += (double)N_ATTN * (2.0*QKV_DIM*DIM + 2*2.0*kv_dim*DIM + 2.0*DIM*DIM) * S;
        double ffn_flops = (double)N_LAYERS * 3.0 * 2.0 * DIM * FFN * S;
        double total_flops = proj_flops + ffn_flops;
        printf(" Total FLOPs: %.1f TFLOP\n", total_flops / 1e12);
        printf(" Effective TFLOPS: %.2f\n", total_flops / (total_ms * 1e-3) / 1e12);
        printf(" ANE-only TFLOPS: %.2f\n",
               (proj_flops + ffn_flops) / ((total_ane_proj_ms + total_ane_ffn_ms) * 1e-3) / 1e12);

        // Cleanup
        ane_free(k_proj_large);
        ane_free(k_proj_medium);
        ane_free(k_proj_small);
        ane_free(k_ffn);
        free(hidden); free(normed); free(proj_out);
        free(qkv_buf); free(gz_buf); free(q_buf); free(k_buf);
        free(v_buf); free(o_buf); free(ffn_out);
        free(dn_state); free(dn_tmp);
        free(k_cache); free(v_cache);
        free(attn_scores); free(norm_w);
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment