# rms norm + attention after.py  (gist by @zoecarver)
# Note: TILE, dim_tiles, seq_tiles, n_head, and HEAD_TILES, plus the DRAM
# tensors referenced inside the kernel bodies, are assumed to come from the
# surrounding module / operation context.
@ttl.compute()
def rmsnorm_compute(x_dfb, sc_dfb, ms_dfb, sq_dfb, red_dfb, acc_dfb, bcast_dfb, rsq_dfb, out_dfb):
    core_x, _ = ttl.node(dims=2)
    sc = sc_dfb.pop_back_val()
    ms = ms_dfb.pop_back_val()
    for local_t in range(tiles_per_core):
        tile_idx = core_x * tiles_per_core + local_t
        if tile_idx < seq_tiles:
            # Pass 1: sum of squares across dim tiles
            x0 = x_dfb.pop_back_val()
            sq_dfb.push(x0 * x0)
            sqv = sq_dfb.pop_back_val()
            red_dfb.push(ttl.math.reduce_sum(sqv, sc, dims=[1]))
            rv = red_dfb.pop_back_val()
            acc_dfb.push(rv)
            for j in range(dim_tiles - 1):
                xj = x_dfb.pop_back_val()
                sq_dfb.push(xj * xj)
                sqv = sq_dfb.pop_back_val()
                red_dfb.push(ttl.math.reduce_sum(sqv, sc, dims=[1]))
                rv = red_dfb.pop_back_val()
                av = acc_dfb.pop_back_val()
                acc_dfb.push(av + rv)
            # Broadcast the total, scale by 1/N, add eps, take rsqrt
            total = acc_dfb.pop_back_val()
            bcast_dfb.push(ttl.math.broadcast(total, bcast_dfb, dims=[1]))
            bv = bcast_dfb.pop_back_val()
            red_dfb.push(bv * ms + ttl.math.fill(bv, 1e-5))
            msq = red_dfb.pop_back_val()
            rsq_dfb.push(ttl.math.rsqrt(msq))
            # Pass 2: x * rsqrt(mean(x^2) + eps)
            rsqv = rsq_dfb.pop_back_val()
            for j in range(dim_tiles):
                xj = x_dfb.pop_back_val()
                out_dfb.push(xj * rsqv)
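# --- Reference sketch (not part of the kernel) ---
# What rmsnorm_compute implements per sequence row: y = x * rsqrt(mean(x^2) + eps).
# In the kernel, `ms` carries the 1/N mean scale, `sc` is the reduce scaler tile,
# and eps = 1e-5 matches the fill above. `rmsnorm_reference` is a hypothetical
# NumPy helper added here for illustration only.
import numpy as np

def rmsnorm_reference(x, eps=1e-5):
    # x: (seq, dim); each row is scaled by rsqrt(mean of its squared entries + eps)
    mean_sq = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(mean_sq + eps)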
@ttl.datamovement()
def rmsnorm_dm_read(x_dfb, sc_dfb, ms_dfb):
    core_x, _ = ttl.node(dims=2)
    tx1 = ttl.copy(scaler[0, 0], sc_dfb.reserve())
    tx2 = ttl.copy(mean_scale[0, 0], ms_dfb.reserve())
    tx1.wait(); tx2.wait()
    for local_t in range(tiles_per_core):
        tile_idx = core_x * tiles_per_core + local_t
        if tile_idx < seq_tiles:
            # Stream x twice: once for the sum-of-squares pass,
            # once more for the normalization pass.
            for j in range(dim_tiles):
                tx = ttl.copy(x[tile_idx, j], x_dfb.reserve()); tx.wait()
            for j in range(dim_tiles):
                tx = ttl.copy(x[tile_idx, j], x_dfb.reserve()); tx.wait()
@ttl.datamovement()
def rmsnorm_dm_write(out_dfb):
    core_x, _ = ttl.node(dims=2)
    for local_t in range(tiles_per_core):
        tile_idx = core_x * tiles_per_core + local_t
        if tile_idx < seq_tiles:
            for j in range(dim_tiles):
                tx = ttl.copy(out_dfb.pop_back_val(), out[tile_idx, j]); tx.wait()
@ttl.operation(grid="auto")
def rmsnorm_kernel(x, scaler, mean_scale, out):
    grid_cols, _ = ttl.grid_size(dims=2)
    seq_tiles = x.shape[0] // TILE
    tiles_per_core = -(-seq_tiles // grid_cols)  # ceiling division
    x_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    sc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    ms_dfb = ttl.make_dataflow_buffer_like(mean_scale, shape=(1, 1), buffer_factor=1)
    sq_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    red_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    acc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    bcast_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    rsq_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    out_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)
    rmsnorm_dm_read(x_dfb, sc_dfb, ms_dfb)
    rmsnorm_compute(x_dfb, sc_dfb, ms_dfb, sq_dfb, red_dfb, acc_dfb,
                    bcast_dfb, rsq_dfb, out_dfb)
    rmsnorm_dm_write(out_dfb)
@ttl.compute()
def attn_compute(q_dfb, k_dfb, v_dfb, sc_dfb, scaler_dfb, ninf_dfb, zero_dfb, zh_dfb,
                 mask_dfb, kt_dfb, qk_dfb, scaled_dfb, cm_dfb, m_new_dfb, alpha_dfb,
                 alpha_bc_dfb, exp_dfb, cs_dfb, co_dfb, m_dfb, l_dfb, o_dfb,
                 l_bc_dfb, out_dfb, m_out_dfb, l_out_dfb):
    h, _ = ttl.node(dims=2)
    sc_blk = sc_dfb.pop_back_val()
    scaler_blk = scaler_dfb.pop_back_val()
    ninf_blk = ninf_dfb.pop_back_val()
    zero_blk = zero_dfb.pop_back_val()
    zh_blk = zh_dfb.pop_back_val()
    for q_row in range(seq_tiles):
        q_blk = q_dfb.pop_back_val()
        # Init running state: m = -inf, l = 0, o = 0
        m_dfb.push(ninf_blk)
        l_dfb.push(zero_blk)
        o_dfb.push(zh_blk)
        for kv_col in range(q_row + 1):
            kc = k_dfb.pop_back_val()
            kt_dfb.push(ttl.transpose(kc))
            ktv = kt_dfb.pop_back_val()
            qk_dfb.push(q_blk @ ktv)
            qkv = qk_dfb.pop_back_val()
            mv = mask_dfb.pop_back_val()
            scaled_dfb.push(sc_blk * qkv + mv)
            # Online softmax: update the running max, rescale old state by
            # alpha = exp(m_old - m_new), accumulate exp sums and output
            sd = scaled_dfb.pop_back_val()
            cm_dfb.push(ttl.math.reduce_max(sd, scaler_blk, dims=[1]))
            m_old = m_dfb.pop_back_val()
            cm = cm_dfb.pop_back_val()
            m_new_dfb.push(ttl.math.max(m_old, cm))
            mn = m_new_dfb.pop_back_val()
            alpha_dfb.push(ttl.math.exp(m_old - mn))
            exp_dfb.push(ttl.math.exp(sd - mn))
            m_dfb.push(mn)
            exp_blk = exp_dfb.pop_back_val()
            cs_dfb.push(ttl.math.reduce_sum(exp_blk, scaler_blk, dims=[1]))
            alpha_blk = alpha_dfb.pop_back_val()
            l_old = l_dfb.pop_back_val()
            cs = cs_dfb.pop_back_val()
            l_dfb.push(alpha_blk * l_old + cs)
            alpha_bc_dfb.push(ttl.math.broadcast(alpha_blk, alpha_bc_dfb, dims=[1]))
            abc = alpha_bc_dfb.pop_back_val()
            o_old = o_dfb.pop_back_val()
            co_dfb.push(abc * o_old)
            co = co_dfb.pop_back_val()
            vc = v_dfb.pop_back_val()
            o_dfb.push(co + exp_blk @ vc)
        # Save m and l, normalize o
        m_final = m_dfb.pop_back_val()
        m_out_dfb.push(m_final)
        l_final = l_dfb.pop_back_val()
        l_out_dfb.push(l_final)
        l_bc_dfb.push(ttl.math.broadcast(l_final, l_bc_dfb, dims=[1]))
        o_final = o_dfb.pop_back_val()
        lbc = l_bc_dfb.pop_back_val()
        out_dfb.push(o_final / lbc)
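# --- Reference sketch (not part of the kernel) ---
# The kv_col loop above is the standard online-softmax (flash-attention style)
# recurrence: keep a running row max m, row sum l, and unnormalized output o,
# rescaling the old state by alpha = exp(m_old - m_new) whenever a new block
# raises the max. A hypothetical NumPy helper, added for illustration only
# (`scale` plays the role of sc_blk, `masks` the role of the mask_dfb stream):
def online_softmax_reference(q, k_blocks, v_blocks, masks, scale):
    m = np.full((q.shape[0], 1), -np.inf)             # running row max
    l = np.zeros((q.shape[0], 1))                     # running row sum of exp
    o = np.zeros((q.shape[0], v_blocks[0].shape[1]))  # unnormalized output
    for kc, vc, mask in zip(k_blocks, v_blocks, masks):
        s = scale * (q @ kc.T) + mask                 # scaled, masked scores
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)                     # rescale factor for old state
        p = np.exp(s - m_new)
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        o = alpha * o + p @ vc
        m = m_new
    return o / l, m, l                                # normalized output plus m, l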
@ttl.datamovement()
def attn_dm_read(q_dfb, k_dfb, v_dfb, sc_dfb, scaler_dfb, ninf_dfb, zero_dfb, zh_dfb, mask_dfb):
    h, _ = ttl.node(dims=2)
    kv_base = h * seq_tiles
    # Load constants
    tx = ttl.copy(scale_tile[0, 0], sc_dfb.reserve()); tx.wait()
    tx = ttl.copy(scaler[0, 0], scaler_dfb.reserve()); tx.wait()
    tx = ttl.copy(neg_inf_tile[0, 0], ninf_dfb.reserve()); tx.wait()
    tx = ttl.copy(zero_tile[0, 0], zero_dfb.reserve()); tx.wait()
    tx = ttl.copy(zero_head[0, 0:HEAD_TILES], zh_dfb.reserve()); tx.wait()
    for q_row in range(seq_tiles):
        tx = ttl.copy(Q[kv_base + q_row:kv_base + q_row + 1, 0:HEAD_TILES], q_dfb.reserve())
        tx.wait()
        for kv_col in range(q_row + 1):
            tx = ttl.copy(K[kv_base + kv_col:kv_base + kv_col + 1, 0:HEAD_TILES], k_dfb.reserve())
            tx.wait()
            tx = ttl.copy(V[kv_base + kv_col:kv_base + kv_col + 1, 0:HEAD_TILES], v_dfb.reserve())
            tx.wait()
            # Only the diagonal block needs the causal mask; earlier blocks are fully visible
            if kv_col == q_row:
                tx = ttl.copy(causal_mask[0, 0], mask_dfb.reserve()); tx.wait()
            else:
                tx = ttl.copy(zero_tile[0, 0], mask_dfb.reserve()); tx.wait()
@ttl.datamovement()
def attn_dm_write(out_dfb, m_out_dfb, l_out_dfb):
    h, _ = ttl.node(dims=2)
    out_base = h * seq_tiles
    for q_row in range(seq_tiles):
        tx = ttl.copy(m_out_dfb.pop_back_val(), m_out[out_base + q_row, 0]); tx.wait()
        tx = ttl.copy(l_out_dfb.pop_back_val(), l_out[out_base + q_row, 0]); tx.wait()
        tx = ttl.copy(out_dfb.pop_back_val(), out[out_base + q_row:out_base + q_row + 1, 0:HEAD_TILES])
        tx.wait()
@ttl.operation(grid=(n_head, 1))
def training_attention(Q, K, V, scale_tile, scaler, neg_inf_tile,
                       zero_tile, zero_head, causal_mask,
                       out, m_out, l_out):
    q_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, HEAD_TILES), buffer_factor=2)
    k_dfb = ttl.make_dataflow_buffer_like(K, shape=(1, HEAD_TILES), buffer_factor=2)
    v_dfb = ttl.make_dataflow_buffer_like(V, shape=(1, HEAD_TILES), buffer_factor=2)
    sc_dfb = ttl.make_dataflow_buffer_like(scale_tile, shape=(1, 1), buffer_factor=1)
    scaler_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    ninf_dfb = ttl.make_dataflow_buffer_like(neg_inf_tile, shape=(1, 1), buffer_factor=1)
    zero_dfb = ttl.make_dataflow_buffer_like(zero_tile, shape=(1, 1), buffer_factor=1)
    zh_dfb = ttl.make_dataflow_buffer_like(zero_head, shape=(1, HEAD_TILES), buffer_factor=1)
    mask_dfb = ttl.make_dataflow_buffer_like(causal_mask, shape=(1, 1), buffer_factor=2)
    kt_dfb = ttl.make_dataflow_buffer_like(K, shape=(HEAD_TILES, 1), buffer_factor=2)
    qk_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    scaled_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    cm_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    m_new_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    alpha_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    alpha_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    exp_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    cs_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    co_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    m_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    l_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    o_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    l_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    out_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    m_out_dfb = ttl.make_dataflow_buffer_like(m_out, shape=(1, 1), buffer_factor=2)
    l_out_dfb = ttl.make_dataflow_buffer_like(l_out, shape=(1, 1), buffer_factor=2)
    attn_dm_read(q_dfb, k_dfb, v_dfb, sc_dfb, scaler_dfb, ninf_dfb, zero_dfb, zh_dfb, mask_dfb)
    attn_compute(q_dfb, k_dfb, v_dfb, sc_dfb, scaler_dfb, ninf_dfb, zero_dfb, zh_dfb,
                 mask_dfb, kt_dfb, qk_dfb, scaled_dfb, cm_dfb, m_new_dfb, alpha_dfb,
                 alpha_bc_dfb, exp_dfb, cs_dfb, co_dfb, m_dfb, l_dfb, o_dfb,
                 l_bc_dfb, out_dfb, m_out_dfb, l_out_dfb)
    attn_dm_write(out_dfb, m_out_dfb, l_out_dfb)
# ========================================================================
# ============================== FUSED ===================================
# ========================================================================
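# In the fused kernel below, q_dfb is fed on-core by rmsnorm_compute, so the
# attention reader must not also stream Q from DRAM: that would double-feed
# q_dfb. The variant below is an assumed sketch, not from the original gist:
# it is attn_dm_read with the Q copies dropped and the softmax scaler read
# from attn_scaler (the fused op's name for that constant); everything else
# is unchanged.
@ttl.datamovement()
def attn_dm_read_fused(k_dfb, v_dfb, sc_dfb, scaler_dfb, ninf_dfb, zero_dfb, zh_dfb, mask_dfb):
    h, _ = ttl.node(dims=2)
    kv_base = h * seq_tiles
    # Load constants
    tx = ttl.copy(scale_tile[0, 0], sc_dfb.reserve()); tx.wait()
    tx = ttl.copy(attn_scaler[0, 0], scaler_dfb.reserve()); tx.wait()
    tx = ttl.copy(neg_inf_tile[0, 0], ninf_dfb.reserve()); tx.wait()
    tx = ttl.copy(zero_tile[0, 0], zero_dfb.reserve()); tx.wait()
    tx = ttl.copy(zero_head[0, 0:HEAD_TILES], zh_dfb.reserve()); tx.wait()
    for q_row in range(seq_tiles):
        for kv_col in range(q_row + 1):
            tx = ttl.copy(K[kv_base + kv_col:kv_base + kv_col + 1, 0:HEAD_TILES], k_dfb.reserve())
            tx.wait()
            tx = ttl.copy(V[kv_base + kv_col:kv_base + kv_col + 1, 0:HEAD_TILES], v_dfb.reserve())
            tx.wait()
            if kv_col == q_row:
                tx = ttl.copy(causal_mask[0, 0], mask_dfb.reserve()); tx.wait()
            else:
                tx = ttl.copy(zero_tile[0, 0], mask_dfb.reserve()); tx.wait()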
@ttl.operation(grid=(n_head, 1))
def fused_rmsnorm_attention(x, scaler, mean_scale,
                            K, V, scale_tile, attn_scaler, neg_inf_tile,
                            zero_tile, zero_head, causal_mask,
                            out, m_out, l_out):
    grid_cols, _ = ttl.grid_size(dims=2)
    seq_tiles = x.shape[0] // TILE
    tiles_per_core = -(-seq_tiles // grid_cols)  # ceiling division
    # === RMSNorm DFBs ===
    x_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    sc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    ms_dfb = ttl.make_dataflow_buffer_like(mean_scale, shape=(1, 1), buffer_factor=1)
    sq_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    red_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    acc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    bcast_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    rsq_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    # === SHARED: rmsnorm's output IS attention's Q input ===
    q_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    # rmsnorm writes into q_dfb, attention reads from q_dfb; no DRAM round trip
    # === Attention DFBs ===
    k_dfb = ttl.make_dataflow_buffer_like(K, shape=(1, HEAD_TILES), buffer_factor=2)
    v_dfb = ttl.make_dataflow_buffer_like(V, shape=(1, HEAD_TILES), buffer_factor=2)
    attn_sc_dfb = ttl.make_dataflow_buffer_like(scale_tile, shape=(1, 1), buffer_factor=1)
    attn_scaler_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=1)
    ninf_dfb = ttl.make_dataflow_buffer_like(neg_inf_tile, shape=(1, 1), buffer_factor=1)
    zero_dfb = ttl.make_dataflow_buffer_like(zero_tile, shape=(1, 1), buffer_factor=1)
    zh_dfb = ttl.make_dataflow_buffer_like(zero_head, shape=(1, HEAD_TILES), buffer_factor=1)
    mask_dfb = ttl.make_dataflow_buffer_like(causal_mask, shape=(1, 1), buffer_factor=2)
    kt_dfb = ttl.make_dataflow_buffer_like(K, shape=(HEAD_TILES, 1), buffer_factor=2)
    qk_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)
    scaled_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)
    cm_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    m_new_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    alpha_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    alpha_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    exp_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)
    cs_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    co_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    m_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    l_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    o_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    l_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    out_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    m_out_dfb = ttl.make_dataflow_buffer_like(m_out, shape=(1, 1), buffer_factor=2)
    l_out_dfb = ttl.make_dataflow_buffer_like(l_out, shape=(1, 1), buffer_factor=2)
    # === Wire it up: rmsnorm → attention, no DRAM in between ===
    # Datamovement: rmsnorm reads x; attention reads only K/V/constants,
    # since q_dfb is fed on-core by rmsnorm_compute (see attn_dm_read_fused above)
    rmsnorm_dm_read(x_dfb, sc_dfb, ms_dfb)
    attn_dm_read_fused(k_dfb, v_dfb, attn_sc_dfb, attn_scaler_dfb,
                       ninf_dfb, zero_dfb, zh_dfb, mask_dfb)
    # Compute: rmsnorm produces into q_dfb, attention consumes from q_dfb
    rmsnorm_compute(x_dfb, sc_dfb, ms_dfb, sq_dfb, red_dfb, acc_dfb,
                    bcast_dfb, rsq_dfb, q_dfb)  # note: q_dfb is rmsnorm's "out_dfb"
    attn_compute(q_dfb, k_dfb, v_dfb, attn_sc_dfb, attn_scaler_dfb,
                 ninf_dfb, zero_dfb, zh_dfb, mask_dfb,
                 kt_dfb, qk_dfb, scaled_dfb, cm_dfb, m_new_dfb, alpha_dfb,
                 alpha_bc_dfb, exp_dfb, cs_dfb, co_dfb,
                 m_dfb, l_dfb, o_dfb, l_bc_dfb, out_dfb, m_out_dfb, l_out_dfb)
    # Only attention writes to DRAM
    attn_dm_write(out_dfb, m_out_dfb, l_out_dfb)
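# --- Reference sketch (not part of the kernel) ---
# End to end, the fused op is intended to compute, per head h:
#     out[h] = causal_attention(rmsnorm(x), K[h], V[h])
# The only change versus running rmsnorm_kernel and then training_attention is
# that the normalized activations never round-trip through DRAM. A hypothetical
# NumPy check built from the reference helpers defined earlier in this file
# (per-head layout of K and V is an assumption):
def fused_reference(x, K, V, scale, eps=1e-5):
    # x: (seq, dim); K, V: (n_head, seq, d) stacks of per-head matrices
    q = rmsnorm_reference(x, eps)
    seq = q.shape[0]
    # Full causal mask applied as a single block for simplicity
    mask = np.where(np.tril(np.ones((seq, seq), dtype=bool)), 0.0, -np.inf)
    outs = []
    for h in range(K.shape[0]):
        o, m, l = online_softmax_reference(q, [K[h]], [V[h]], [mask], scale)
        outs.append(o)
    return np.stack(outs)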