Skip to content

Instantly share code, notes, and snippets.

@zoecarver
Last active April 7, 2026 01:46
Show Gist options
  • Select an option

  • Save zoecarver/e6b9b7a002f3f9cece2a80c5e5557b8e to your computer and use it in GitHub Desktop.

Select an option

Save zoecarver/e6b9b7a002f3f9cece2a80c5e5557b8e to your computer and use it in GitHub Desktop.
RMSNorm and training-attention kernels ("before" version)
# RMSNorm without a learned weight, computed one row of tiles at a time:
#     out = x * rsqrt(rowsum(x * x) * mean_scale + 1e-5)
# NOTE(review): mean_scale is presumably a tile filled with 1/N (N = row width
# of x in elements), which makes the middle term mean(x^2) -- confirm at the
# call site.
# NOTE(review): indentation was lost in the pasted source; the nesting of the
# `with` blocks below was reconstructed from their data dependencies --
# confirm against the original.
# TILE and dim_tiles are free names resolved from an enclosing scope that is
# not visible here.
@ttl.operation(grid="auto")
def rmsnorm_kernel(x, scaler, mean_scale, out):
    # Split the seq_tiles tile rows of x across cores (indexed by core_x):
    # each core owns a contiguous run of ceil(seq_tiles / grid_cols) rows, so
    # the last cores may have fewer valid rows or none.
    grid_cols, _ = ttl.grid_size(dims=2)
    seq_tiles = x.shape[0] // TILE
    tiles_per_core = -(-seq_tiles // grid_cols)  # ceiling division
    # Every buffer holds single (1, 1) tiles. The two constants use
    # buffer_factor=1 because they are loaded once. acc_dfb needs
    # buffer_factor=2 because the accumulation step holds the old sum (wait)
    # and the new sum (reserve) at the same time.
    x_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    sc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    ms_dfb = ttl.make_dataflow_buffer_like(mean_scale, shape=(1, 1), buffer_factor=1)
    sq_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    red_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    acc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    bcast_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    rsq_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    out_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)

    @ttl.compute()
    def compute():
        core_x, _ = ttl.node(dims=2)
        # The two constant tiles are held for the whole kernel.
        # NOTE(review): scaler is passed as the scaler operand of reduce_sum;
        # presumably a tile of ones -- confirm.
        with sc_dfb.wait() as sc, ms_dfb.wait() as ms:
            for local_t in range(tiles_per_core):
                tile_idx = core_x * tiles_per_core + local_t
                # Trailing cores may have iterations past the last tile row.
                # dm_read and dm_write use the same guard, so all three
                # threads agree on how many tiles move through each buffer.
                if tile_idx < seq_tiles:
                    # Pass 1: sum of squares across dim tiles
                    # The first dim tile seeds the accumulator ...
                    with x_dfb.wait() as x0:
                        with sq_dfb.reserve() as sq:
                            sq.store(x0 * x0)
                    with sq_dfb.wait() as sqv, red_dfb.reserve() as r:
                        r.store(ttl.math.reduce_sum(sqv, sc, dims=[1]))
                    with red_dfb.wait() as rv, acc_dfb.reserve() as acc:
                        acc.store(rv)
                    # ... and each of the remaining dim_tiles - 1 tiles adds
                    # its reduced sum of squares to it.
                    for j in range(dim_tiles - 1):
                        with x_dfb.wait() as xj:
                            with sq_dfb.reserve() as sq:
                                sq.store(xj * xj)
                        with sq_dfb.wait() as sqv, red_dfb.reserve() as r:
                            r.store(ttl.math.reduce_sum(sqv, sc, dims=[1]))
                        with red_dfb.wait() as rv, acc_dfb.wait() as av, acc_dfb.reserve() as new_acc:
                            new_acc.store(av + rv)
                    # broadcast, scale by 1/N, add eps, rsqrt
                    # The reduced sum is broadcast along dim 1 before being
                    # combined elementwise with full tiles. red_dfb is reused
                    # here to hold sum * mean_scale + 1e-5.
                    with acc_dfb.wait() as total, bcast_dfb.reserve() as bc:
                        bc.store(ttl.math.broadcast(total, bc, dims=[1]))
                    with bcast_dfb.wait() as bv, red_dfb.reserve() as scaled:
                        scaled.store(bv * ms + ttl.math.fill(bv, 1e-5))
                    with red_dfb.wait() as msq, rsq_dfb.reserve() as rsq:
                        rsq.store(ttl.math.rsqrt(msq))
                    # Pass 2: x * rsqrt(mean(x^2) + eps)
                    # dm_read streams every dim tile of this row a second
                    # time for this pass.
                    with rsq_dfb.wait() as rsqv:
                        for j in range(dim_tiles):
                            with x_dfb.wait() as xj, out_dfb.reserve() as o:
                                o.store(xj * rsqv)

    @ttl.datamovement()
    def dm_read():
        core_x, _ = ttl.node(dims=2)
        # Load both constants before streaming x. compute() waits on
        # sc_dfb/ms_dfb before its first x_dfb.wait(), so the x loop must sit
        # outside this `with`.
        with sc_dfb.reserve() as blk1, ms_dfb.reserve() as blk2:
            tx1 = ttl.copy(scaler[0, 0], blk1)
            tx2 = ttl.copy(mean_scale[0, 0], blk2)
            tx1.wait(); tx2.wait()
        for local_t in range(tiles_per_core):
            tile_idx = core_x * tiles_per_core + local_t
            if tile_idx < seq_tiles:
                # Row tile_idx is read twice: once for the sum-of-squares
                # pass ...
                for j in range(dim_tiles):
                    with x_dfb.reserve() as blk:
                        tx = ttl.copy(x[tile_idx, j], blk); tx.wait()
                # ... and once more for the normalize pass.
                for j in range(dim_tiles):
                    with x_dfb.reserve() as blk:
                        tx = ttl.copy(x[tile_idx, j], blk); tx.wait()

    @ttl.datamovement()
    def dm_write():
        core_x, _ = ttl.node(dims=2)
        # Drain out_dfb in the same (row, dim tile) order compute() fills it.
        for local_t in range(tiles_per_core):
            tile_idx = core_x * tiles_per_core + local_t
            if tile_idx < seq_tiles:
                for j in range(dim_tiles):
                    with out_dfb.wait() as blk:
                        tx = ttl.copy(blk, out[tile_idx, j]); tx.wait()
# Causal attention with an online (streaming) softmax, one head per core along
# grid dim 0 (grid=(n_head, 1); head index h comes from ttl.node).
# For query tile row q_row, only key/value tile rows kv_col <= q_row are
# visited. causal_mask is added only on the diagonal tile (kv_col == q_row);
# every other step adds zero_tile.
# Per query tile row it writes:
#     out   = (sum_j exp(s_j - m) V_j) / l, with s = scale * Q K^T + mask
#     m_out = final running row max of the scaled scores
#     l_out = final running softmax denominator
# Q/K/V/out rows for head h start at tile row h * seq_tiles (see kv_base and
# out_base).
# NOTE(review): m_out / l_out are presumably kept for a backward pass --
# confirm.
# NOTE(review): scale_tile is presumably filled with 1/sqrt(head_dim), and
# scaler (the scaler operand of reduce_max/reduce_sum) with ones -- confirm.
# NOTE(review): indentation was lost in the pasted source; the nesting of the
# `with` blocks below was reconstructed from their data dependencies --
# confirm against the original.
# seq_tiles, HEAD_TILES and n_head are free names from an enclosing scope not
# visible here.
@ttl.operation(grid=(n_head, 1))
def training_attention(Q, K, V, scale_tile, scaler, neg_inf_tile,
                       zero_tile, zero_head, causal_mask,
                       out, m_out, l_out):
    # Streamed inputs: one tile row x HEAD_TILES tiles per block.
    q_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, HEAD_TILES), buffer_factor=2)
    k_dfb = ttl.make_dataflow_buffer_like(K, shape=(1, HEAD_TILES), buffer_factor=2)
    v_dfb = ttl.make_dataflow_buffer_like(V, shape=(1, HEAD_TILES), buffer_factor=2)
    # Constants, loaded once by dm_read and held by compute() throughout.
    sc_dfb = ttl.make_dataflow_buffer_like(scale_tile, shape=(1, 1), buffer_factor=1)
    scaler_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    ninf_dfb = ttl.make_dataflow_buffer_like(neg_inf_tile, shape=(1, 1), buffer_factor=1)
    zero_dfb = ttl.make_dataflow_buffer_like(zero_tile, shape=(1, 1), buffer_factor=1)
    zh_dfb = ttl.make_dataflow_buffer_like(zero_head, shape=(1, HEAD_TILES), buffer_factor=1)
    # Per-step additive mask tile.
    mask_dfb = ttl.make_dataflow_buffer_like(causal_mask, shape=(1, 1), buffer_factor=2)
    # Intermediates of one kv step. kt holds K^T as HEAD_TILES x 1 tiles, so
    # q @ k^T yields a single (1, 1) score tile.
    kt_dfb = ttl.make_dataflow_buffer_like(K, shape=(HEAD_TILES, 1), buffer_factor=2)
    qk_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    scaled_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    cm_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    m_new_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    alpha_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    alpha_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    exp_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    cs_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    co_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    # Running softmax state for the current query row. buffer_factor=2 lets an
    # update hold the old value (wait) and the new value (reserve) at once.
    m_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    l_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    o_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    # Finalization: l broadcast over the head block, then the three outputs.
    l_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    out_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    m_out_dfb = ttl.make_dataflow_buffer_like(m_out, shape=(1, 1), buffer_factor=2)
    l_out_dfb = ttl.make_dataflow_buffer_like(l_out, shape=(1, 1), buffer_factor=2)

    @ttl.compute()
    def compute():
        h, _ = ttl.node(dims=2)
        # Hold every constant tile for the whole kernel.
        with sc_dfb.wait() as sc_blk, scaler_dfb.wait() as scaler_blk, \
                ninf_dfb.wait() as ninf_blk, zero_dfb.wait() as zero_blk, \
                zh_dfb.wait() as zh_blk:
            for q_row in range(seq_tiles):
                with q_dfb.wait() as q_blk:
                    # Init running state: m=-inf, l=0, o=0
                    with m_dfb.reserve() as mi:
                        mi.store(ninf_blk)
                    with l_dfb.reserve() as li:
                        li.store(zero_blk)
                    with o_dfb.reserve() as oi:
                        oi.store(zh_blk)
                    # Causal: only kv tile rows 0..q_row contribute.
                    for kv_col in range(q_row + 1):
                        # s = scale * (q @ k^T) + mask
                        with k_dfb.wait() as kc, kt_dfb.reserve() as kt:
                            kt.store(ttl.transpose(kc))
                        with kt_dfb.wait() as ktv, qk_dfb.reserve() as qk:
                            qk.store(q_blk @ ktv)
                        with qk_dfb.wait() as qkv, mask_dfb.wait() as mv:
                            with scaled_dfb.reserve() as scd:
                                scd.store(sc_blk * qkv + mv)
                        # Online softmax
                        with scaled_dfb.wait() as sd:
                            # cm = row max of this score tile
                            with cm_dfb.reserve() as cm:
                                cm.store(ttl.math.reduce_max(sd, scaler_blk, dims=[1]))
                            with m_dfb.wait() as m_old:
                                # m_new = max(m_old, cm)
                                with cm_dfb.wait() as cm:
                                    with m_new_dfb.reserve() as mn:
                                        mn.store(ttl.math.max(m_old, cm))
                                with m_new_dfb.wait() as mn:
                                    # alpha = exp(m_old - m_new) rescales the
                                    # previously accumulated l and o.
                                    with alpha_dfb.reserve() as alpha:
                                        alpha.store(ttl.math.exp(m_old - mn))
                                    # p = exp(s - m_new)
                                    # NOTE(review): mn is a reduce_max result
                                    # subtracted from the full score tile sd
                                    # without an explicit ttl.math.broadcast,
                                    # whereas rmsnorm_kernel broadcasts its
                                    # reduce result along dims=[1] before
                                    # elementwise use. If elementwise ops do not
                                    # broadcast reduce outputs implicitly, not
                                    # every column of sd is shifted by the row
                                    # max -- verify.
                                    with exp_dfb.reserve() as ex:
                                        ex.store(ttl.math.exp(sd - mn))
                                    # m_new becomes the running max.
                                    with m_dfb.reserve() as m_next:
                                        m_next.store(mn)
                        # exp_blk (p) stays live through the o update below.
                        with exp_dfb.wait() as exp_blk:
                            with cs_dfb.reserve() as cs:
                                cs.store(ttl.math.reduce_sum(exp_blk, scaler_blk, dims=[1]))
                            with alpha_dfb.wait() as alpha_blk:
                                # l = alpha * l + rowsum(p)
                                with l_dfb.wait() as l_old, cs_dfb.wait() as cs:
                                    with l_dfb.reserve() as l_new:
                                        l_new.store(alpha_blk * l_old + cs)
                                # Broadcast alpha over the (1, HEAD_TILES)
                                # output block.
                                with alpha_bc_dfb.reserve() as abc:
                                    abc.store(ttl.math.broadcast(alpha_blk, abc, dims=[1]))
                            # o = alpha * o + p @ V
                            with alpha_bc_dfb.wait() as abc, o_dfb.wait() as o_old:
                                with co_dfb.reserve() as co:
                                    co.store(abc * o_old)
                            with co_dfb.wait() as co, v_dfb.wait() as vc:
                                with o_dfb.reserve() as o_new:
                                    o_new.store(co + exp_blk @ vc)
                    # Save m and l, normalize o
                    # Order (m_out, l_out, out) matches dm_write.
                    with m_dfb.wait() as m_final:
                        with m_out_dfb.reserve() as ms:
                            ms.store(m_final)
                    with l_dfb.wait() as l_final:
                        with l_out_dfb.reserve() as ls:
                            ls.store(l_final)
                        with l_bc_dfb.reserve() as lbc:
                            lbc.store(ttl.math.broadcast(l_final, lbc, dims=[1]))
                    with o_dfb.wait() as o_final, l_bc_dfb.wait() as lbc:
                        with out_dfb.reserve() as o:
                            o.store(o_final / lbc)

    @ttl.datamovement()
    def dm_read():
        h, _ = ttl.node(dims=2)
        # First tile row of this head's Q/K/V slice.
        kv_base = h * seq_tiles
        # Load constants
        with sc_dfb.reserve() as b:
            tx = ttl.copy(scale_tile[0, 0], b); tx.wait()
        with scaler_dfb.reserve() as b:
            tx = ttl.copy(scaler[0, 0], b); tx.wait()
        with ninf_dfb.reserve() as b:
            tx = ttl.copy(neg_inf_tile[0, 0], b); tx.wait()
        with zero_dfb.reserve() as b:
            tx = ttl.copy(zero_tile[0, 0], b); tx.wait()
        with zh_dfb.reserve() as b:
            tx = ttl.copy(zero_head[0, 0:HEAD_TILES], b); tx.wait()
        # Loop bounds mirror compute(): one Q row, then q_row + 1 kv steps.
        for q_row in range(seq_tiles):
            # Load Q for this row
            with q_dfb.reserve() as b:
                tx = ttl.copy(Q[kv_base + q_row:kv_base + q_row + 1, 0:HEAD_TILES], b)
                tx.wait()
            for kv_col in range(q_row + 1):
                # Load K, V, mask
                with k_dfb.reserve() as b:
                    tx = ttl.copy(K[kv_base + kv_col:kv_base + kv_col + 1, 0:HEAD_TILES], b)
                    tx.wait()
                with v_dfb.reserve() as b:
                    tx = ttl.copy(V[kv_base + kv_col:kv_base + kv_col + 1, 0:HEAD_TILES], b)
                    tx.wait()
                # Diagonal tile gets causal_mask. Tiles strictly below the
                # diagonal are fully visible, so they get an all-zero mask.
                if kv_col == q_row:
                    with mask_dfb.reserve() as b:
                        tx = ttl.copy(causal_mask[0, 0], b); tx.wait()
                else:
                    with mask_dfb.reserve() as b:
                        tx = ttl.copy(zero_tile[0, 0], b); tx.wait()

    @ttl.datamovement()
    def dm_write():
        h, _ = ttl.node(dims=2)
        out_base = h * seq_tiles
        # Per query row, drain m_out, l_out, then out: the order compute()
        # produces them.
        for q_row in range(seq_tiles):
            with m_out_dfb.wait() as b:
                tx = ttl.copy(b, m_out[out_base + q_row, 0]); tx.wait()
            with l_out_dfb.wait() as b:
                tx = ttl.copy(b, l_out[out_base + q_row, 0]); tx.wait()
            with out_dfb.wait() as b:
                tx = ttl.copy(b, out[out_base + q_row:out_base + q_row + 1, 0:HEAD_TILES])
                tx.wait()
return training_attention
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment