Last active
April 7, 2026 01:46
-
-
Save zoecarver/e6b9b7a002f3f9cece2a80c5e5557b8e to your computer and use it in GitHub Desktop.
rmsnorm + attention before
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# RMSNorm kernel. For every tile-row of `x` it writes, tile by tile,
#     out = x * rsqrt(reduce_sum(x * x) * mean_scale + 1e-5)
# where the reduction runs over dims=[1] and is accumulated across the
# `dim_tiles` tiles of the row.
# NOTE(review): `TILE` and `dim_tiles` are not defined in this block; they are
# read from the enclosing scope. `mean_scale` presumably holds 1/N -- confirm
# against the caller.
@ttl.operation(grid="auto")
def rmsnorm_kernel(x, scaler, mean_scale, out):
    # Tile-rows of x are split over the first grid dimension, rounding up.
    grid_cols, _ = ttl.grid_size(dims=2)
    seq_tiles = x.shape[0] // TILE
    tiles_per_core = -(-seq_tiles // grid_cols)  # ceil(seq_tiles / grid_cols)

    # Inputs: streamed x tiles plus two tiles that are loaded once.
    x_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    sc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    ms_dfb = ttl.make_dataflow_buffer_like(mean_scale, shape=(1, 1), buffer_factor=1)
    # Intermediates used only by the compute function.
    sq_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    red_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    acc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    bcast_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    rsq_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    # Result tiles handed to the writer.
    out_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)

    @ttl.compute()
    def compute():
        core_x, _ = ttl.node(dims=2)
        # The reduction scaler and the mean-scale tile stay held for all rows.
        with sc_dfb.wait() as red_scaler, ms_dfb.wait() as mean_scale_blk:
            for local_t in range(tiles_per_core):
                row = core_x * tiles_per_core + local_t
                # The round-up can leave trailing cores with fewer rows.
                if row < seq_tiles:
                    # Pass 1: sum of squares over the row.
                    # The first tile seeds the accumulator.
                    with x_dfb.wait() as x_first, sq_dfb.reserve() as sq_out:
                        sq_out.store(x_first * x_first)
                    with sq_dfb.wait() as sq_in, red_dfb.reserve() as part_out:
                        part_out.store(ttl.math.reduce_sum(sq_in, red_scaler, dims=[1]))
                    with red_dfb.wait() as part, acc_dfb.reserve() as acc_out:
                        acc_out.store(part)
                    # Every remaining tile is folded into the accumulator.
                    for j in range(dim_tiles - 1):
                        with x_dfb.wait() as x_next, sq_dfb.reserve() as sq_out:
                            sq_out.store(x_next * x_next)
                        with sq_dfb.wait() as sq_in, red_dfb.reserve() as part_out:
                            part_out.store(ttl.math.reduce_sum(sq_in, red_scaler, dims=[1]))
                        with red_dfb.wait() as part, acc_dfb.wait() as acc_prev, \
                                acc_dfb.reserve() as acc_out:
                            acc_out.store(acc_prev + part)

                    # Broadcast the sum, apply mean_scale, add eps, take rsqrt.
                    with acc_dfb.wait() as sum_sq, bcast_dfb.reserve() as bcast_out:
                        bcast_out.store(ttl.math.broadcast(sum_sq, bcast_out, dims=[1]))
                    with bcast_dfb.wait() as bcast_in, red_dfb.reserve() as mean_out:
                        mean_out.store(bcast_in * mean_scale_blk + ttl.math.fill(bcast_in, 1e-5))
                    with red_dfb.wait() as mean_in, rsq_dfb.reserve() as rsq_out:
                        rsq_out.store(ttl.math.rsqrt(mean_in))

                    # Pass 2: scale each tile of the row by the rsqrt factor.
                    with rsq_dfb.wait() as inv_rms:
                        for j in range(dim_tiles):
                            with x_dfb.wait() as x_tile, out_dfb.reserve() as out_tile:
                                out_tile.store(x_tile * inv_rms)

    @ttl.datamovement()
    def dm_read():
        core_x, _ = ttl.node(dims=2)
        # Load the two constant tiles once; both copies are issued before
        # either one is waited on.
        with sc_dfb.reserve() as sc_blk, ms_dfb.reserve() as ms_blk:
            sc_xfer = ttl.copy(scaler[0, 0], sc_blk)
            ms_xfer = ttl.copy(mean_scale[0, 0], ms_blk)
            sc_xfer.wait()
            ms_xfer.wait()
        for local_t in range(tiles_per_core):
            row = core_x * tiles_per_core + local_t
            if row < seq_tiles:
                # The row is streamed twice because compute consumes it in
                # two passes: first for the sum of squares ...
                for j in range(dim_tiles):
                    with x_dfb.reserve() as x_blk:
                        xfer = ttl.copy(x[row, j], x_blk)
                        xfer.wait()
                # ... and then again for the scaling pass.
                for j in range(dim_tiles):
                    with x_dfb.reserve() as x_blk:
                        xfer = ttl.copy(x[row, j], x_blk)
                        xfer.wait()

    @ttl.datamovement()
    def dm_write():
        core_x, _ = ttl.node(dims=2)
        for local_t in range(tiles_per_core):
            row = core_x * tiles_per_core + local_t
            if row < seq_tiles:
                # One output tile per input tile, in the order compute emits them.
                for j in range(dim_tiles):
                    with out_dfb.wait() as out_blk:
                        xfer = ttl.copy(out_blk, out[row, j])
                        xfer.wait()
# Causal attention forward pass using an online softmax (running max `m`,
# running sum `l`, running unnormalized output `o`).
# Each node takes the first component of ttl.node(dims=2) as its index and
# works on tile-rows [index * seq_tiles, (index + 1) * seq_tiles) of Q/K/V.
# For query row q it visits key/value rows 0..q. Per query row it writes the
# normalized output to `out`, the final running max to `m_out` and the final
# running sum to `l_out`.
# NOTE(review): `n_head`, `seq_tiles` and `HEAD_TILES` come from the enclosing
# scope, which is not visible here; presumably one grid node per head.
@ttl.operation(grid=(n_head, 1))
def training_attention(Q, K, V, scale_tile, scaler, neg_inf_tile,
                       zero_tile, zero_head, causal_mask,
                       out, m_out, l_out):
    # Streamed inputs.
    q_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, HEAD_TILES), buffer_factor=2)
    k_dfb = ttl.make_dataflow_buffer_like(K, shape=(1, HEAD_TILES), buffer_factor=2)
    v_dfb = ttl.make_dataflow_buffer_like(V, shape=(1, HEAD_TILES), buffer_factor=2)
    # Tiles loaded once and held for the whole kernel.
    sc_dfb = ttl.make_dataflow_buffer_like(scale_tile, shape=(1, 1), buffer_factor=1)
    scaler_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    ninf_dfb = ttl.make_dataflow_buffer_like(neg_inf_tile, shape=(1, 1), buffer_factor=1)
    zero_dfb = ttl.make_dataflow_buffer_like(zero_tile, shape=(1, 1), buffer_factor=1)
    zh_dfb = ttl.make_dataflow_buffer_like(zero_head, shape=(1, HEAD_TILES), buffer_factor=1)
    # Additive mask tile, one per (query row, key row) step.
    mask_dfb = ttl.make_dataflow_buffer_like(causal_mask, shape=(1, 1), buffer_factor=2)
    # Intermediates of a single key/value step.
    kt_dfb = ttl.make_dataflow_buffer_like(K, shape=(HEAD_TILES, 1), buffer_factor=2)
    qk_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    scaled_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    cm_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    m_new_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    alpha_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    alpha_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    exp_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    cs_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    co_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    # Running state carried from one key/value step to the next.
    m_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    l_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    o_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    # Final normalization and results handed to the writer.
    l_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    out_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    m_out_dfb = ttl.make_dataflow_buffer_like(m_out, shape=(1, 1), buffer_factor=2)
    l_out_dfb = ttl.make_dataflow_buffer_like(l_out, shape=(1, 1), buffer_factor=2)

    @ttl.compute()
    def compute():
        head_idx, _ = ttl.node(dims=2)
        with sc_dfb.wait() as scale_blk, scaler_dfb.wait() as red_scaler, \
                ninf_dfb.wait() as neg_inf, zero_dfb.wait() as zero, \
                zh_dfb.wait() as zero_row:
            for q_row in range(seq_tiles):
                with q_dfb.wait() as q_blk:
                    # Seed the running state for this query row.
                    with m_dfb.reserve() as m_init:
                        m_init.store(neg_inf)
                    with l_dfb.reserve() as l_init:
                        l_init.store(zero)
                    with o_dfb.reserve() as o_init:
                        o_init.store(zero_row)

                    for kv_col in range(q_row + 1):
                        # Scores: scale * (Q @ K^T) + mask.
                        with k_dfb.wait() as k_blk, kt_dfb.reserve() as kt_out:
                            kt_out.store(ttl.transpose(k_blk))
                        with kt_dfb.wait() as kt_blk, qk_dfb.reserve() as qk_out:
                            qk_out.store(q_blk @ kt_blk)
                        with qk_dfb.wait() as qk_blk, mask_dfb.wait() as mask_blk, \
                                scaled_dfb.reserve() as s_out:
                            s_out.store(scale_blk * qk_blk + mask_blk)

                        # Update the running max; derive the rescale factor
                        # alpha = exp(m_prev - m_new) and exp(scores - m_new).
                        with scaled_dfb.wait() as s_blk:
                            with cm_dfb.reserve() as rowmax_out:
                                rowmax_out.store(ttl.math.reduce_max(s_blk, red_scaler, dims=[1]))
                            with m_dfb.wait() as m_prev:
                                with cm_dfb.wait() as rowmax, m_new_dfb.reserve() as m_new_out:
                                    m_new_out.store(ttl.math.max(m_prev, rowmax))
                                with m_new_dfb.wait() as m_new:
                                    with alpha_dfb.reserve() as alpha_out:
                                        alpha_out.store(ttl.math.exp(m_prev - m_new))
                                    with exp_dfb.reserve() as p_out:
                                        p_out.store(ttl.math.exp(s_blk - m_new))
                                    with m_dfb.reserve() as m_next:
                                        m_next.store(m_new)

                        # Update the running sum and the running output.
                        with exp_dfb.wait() as p_blk:
                            with cs_dfb.reserve() as rowsum_out:
                                rowsum_out.store(ttl.math.reduce_sum(p_blk, red_scaler, dims=[1]))
                            with alpha_dfb.wait() as alpha:
                                with l_dfb.wait() as l_prev, cs_dfb.wait() as rowsum, \
                                        l_dfb.reserve() as l_next:
                                    l_next.store(alpha * l_prev + rowsum)
                                with alpha_bc_dfb.reserve() as alpha_bc_out:
                                    alpha_bc_out.store(ttl.math.broadcast(alpha, alpha_bc_out, dims=[1]))
                            with alpha_bc_dfb.wait() as alpha_bc, o_dfb.wait() as o_prev, \
                                    co_dfb.reserve() as rescaled_out:
                                rescaled_out.store(alpha_bc * o_prev)
                            with co_dfb.wait() as rescaled, v_dfb.wait() as v_blk, \
                                    o_dfb.reserve() as o_next:
                                o_next.store(rescaled + p_blk @ v_blk)

                    # Emit the final m and l, then the output divided by l.
                    with m_dfb.wait() as m_last, m_out_dfb.reserve() as m_store:
                        m_store.store(m_last)
                    with l_dfb.wait() as l_last:
                        with l_out_dfb.reserve() as l_store:
                            l_store.store(l_last)
                        with l_bc_dfb.reserve() as l_bc_out:
                            l_bc_out.store(ttl.math.broadcast(l_last, l_bc_out, dims=[1]))
                    with o_dfb.wait() as o_last, l_bc_dfb.wait() as l_bc, \
                            out_dfb.reserve() as o_store:
                        o_store.store(o_last / l_bc)

    @ttl.datamovement()
    def dm_read():
        head_idx, _ = ttl.node(dims=2)
        head_base = head_idx * seq_tiles
        # Constant tiles, each loaded once.
        with sc_dfb.reserve() as scale_dst:
            xfer = ttl.copy(scale_tile[0, 0], scale_dst)
            xfer.wait()
        with scaler_dfb.reserve() as scaler_dst:
            xfer = ttl.copy(scaler[0, 0], scaler_dst)
            xfer.wait()
        with ninf_dfb.reserve() as ninf_dst:
            xfer = ttl.copy(neg_inf_tile[0, 0], ninf_dst)
            xfer.wait()
        with zero_dfb.reserve() as zero_dst:
            xfer = ttl.copy(zero_tile[0, 0], zero_dst)
            xfer.wait()
        with zh_dfb.reserve() as zero_row_dst:
            xfer = ttl.copy(zero_head[0, 0:HEAD_TILES], zero_row_dst)
            xfer.wait()

        for q_row in range(seq_tiles):
            # One Q tile-row per query row.
            q_idx = head_base + q_row
            with q_dfb.reserve() as q_dst:
                xfer = ttl.copy(Q[q_idx:q_idx + 1, 0:HEAD_TILES], q_dst)
                xfer.wait()
            for kv_col in range(q_row + 1):
                # K and V tile-rows for this step.
                kv_idx = head_base + kv_col
                with k_dfb.reserve() as k_dst:
                    xfer = ttl.copy(K[kv_idx:kv_idx + 1, 0:HEAD_TILES], k_dst)
                    xfer.wait()
                with v_dfb.reserve() as v_dst:
                    xfer = ttl.copy(V[kv_idx:kv_idx + 1, 0:HEAD_TILES], v_dst)
                    xfer.wait()
                # The diagonal step gets the causal mask tile; every earlier
                # key row gets the zero tile as its mask.
                if kv_col == q_row:
                    with mask_dfb.reserve() as mask_dst:
                        xfer = ttl.copy(causal_mask[0, 0], mask_dst)
                        xfer.wait()
                else:
                    with mask_dfb.reserve() as mask_dst:
                        xfer = ttl.copy(zero_tile[0, 0], mask_dst)
                        xfer.wait()

    @ttl.datamovement()
    def dm_write():
        head_idx, _ = ttl.node(dims=2)
        head_base = head_idx * seq_tiles
        for q_row in range(seq_tiles):
            row = head_base + q_row
            # Drain in the order compute produces them: m, then l, then out.
            with m_out_dfb.wait() as m_src:
                xfer = ttl.copy(m_src, m_out[row, 0])
                xfer.wait()
            with l_out_dfb.wait() as l_src:
                xfer = ttl.copy(l_src, l_out[row, 0])
                xfer.wait()
            with out_dfb.wait() as out_src:
                xfer = ttl.copy(out_src, out[row:row + 1, 0:HEAD_TILES])
                xfer.wait()
| return training_attention |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment