Created
April 7, 2026 01:46
-
-
Save zoecarver/559691c360554f38caa2db0d8827b22e to your computer and use it in GitHub Desktop.
rms norm + attention after.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@ttl.compute()
def compute(x_dfb, sc_dfb, ms_dfb, sq_dfb, red_dfb, acc_dfb, bcast_dfb, rsq_dfb, out_dfb):
    """RMSNorm compute kernel: out = x * rsqrt(mean(x^2, dim=1) + eps), tiled.

    Two passes per owned sequence tile:
      1. accumulate sum(x^2) across the dim tiles,
      2. re-read the same x tiles and scale by rsqrt(mean + eps).
    The x_dfb stream must therefore carry each row's dim tiles TWICE; the
    paired dm_read pushes them twice on purpose.

    NOTE(review): `tiles_per_core`, `seq_tiles` and `dim_tiles` come from an
    enclosing scope (presumably this def is nested inside `rmsnorm_kernel`
    in the original file — indentation was lost in extraction; confirm).
    """
    core_x, _ = ttl.node(dims=2)
    # One-off constants: reduction scaler tile and the 1/N mean-scale tile.
    sc = sc_dfb.pop_back_val()
    ms = ms_dfb.pop_back_val()
    for local_t in range(tiles_per_core):
        tile_idx = core_x * tiles_per_core + local_t
        if tile_idx < seq_tiles:  # tail cores may own fewer than tiles_per_core rows
            # Pass 1: sum of squares across dim tiles
            x0 = x_dfb.pop_back_val()
            sq_dfb.push(x0 * x0)
            sqv = sq_dfb.pop_back_val()
            red_dfb.push(ttl.math.reduce_sum(sqv, sc, dims=[1]))
            rv = red_dfb.pop_back_val()
            acc_dfb.push(rv)  # seed the running accumulator with dim tile 0
            for j in range(dim_tiles - 1):
                xj = x_dfb.pop_back_val()
                sq_dfb.push(xj * xj)
                sqv = sq_dfb.pop_back_val()
                red_dfb.push(ttl.math.reduce_sum(sqv, sc, dims=[1]))
                rv = red_dfb.pop_back_val()
                av = acc_dfb.pop_back_val()
                acc_dfb.push(av + rv)  # acc += rowwise sum(x_j^2)
            # broadcast, scale by 1/N, add eps, rsqrt
            total = acc_dfb.pop_back_val()
            bcast_dfb.push(ttl.math.broadcast(total, bcast_dfb, dims=[1]))
            bv = bcast_dfb.pop_back_val()
            # mean(x^2) + eps; eps (1e-5) is materialized as a filled tile
            red_dfb.push(bv * ms + ttl.math.fill(bv, 1e-5))
            msq = red_dfb.pop_back_val()
            rsq_dfb.push(ttl.math.rsqrt(msq))
            # Pass 2: x * rsqrt(mean(x^2) + eps)
            rsqv = rsq_dfb.pop_back_val()
            for j in range(dim_tiles):
                xj = x_dfb.pop_back_val()  # second streaming of the same row
                out_dfb.push(xj * rsqv)
@ttl.datamovement()
def dm_read(x_dfb, sc_dfb, ms_dfb):
    """Feed the RMSNorm compute kernel.

    Pushes the two constant tiles once, then streams each owned row's dim
    tiles TWICE — the compute kernel consumes them once for the sum-of-squares
    pass and once for the normalization pass. The repeated loop below is
    intentional, not a copy/paste bug.
    """
    core_x, _ = ttl.node(dims=2)
    # Constants first: reduction scaler and 1/N mean scale (one tile each).
    tx1 = ttl.copy(scaler[0, 0], sc_dfb.reserve())
    tx2 = ttl.copy(mean_scale[0, 0], ms_dfb.reserve())
    tx1.wait(); tx2.wait()
    for local_t in range(tiles_per_core):
        tile_idx = core_x * tiles_per_core + local_t
        if tile_idx < seq_tiles:
            # First streaming of the row: consumed by pass 1 (sum of squares).
            for j in range(dim_tiles):
                tx = ttl.copy(x[tile_idx, j], x_dfb.reserve()); tx.wait()
            # Second streaming of the same row: consumed by pass 2 (normalize).
            for j in range(dim_tiles):
                tx = ttl.copy(x[tile_idx, j], x_dfb.reserve()); tx.wait()
@ttl.datamovement()
def dm_write(out_dfb):
    """Drain normalized tiles from the compute kernel back to the out tensor."""
    col, _ = ttl.node(dims=2)
    for offset in range(tiles_per_core):
        row = col * tiles_per_core + offset
        if row >= seq_tiles:
            # Tail core: no further rows to drain.
            continue
        for d in range(dim_tiles):
            xfer = ttl.copy(out_dfb.pop_back_val(), out[row, d])
            xfer.wait()
@ttl.operation(grid="auto")
def rmsnorm_kernel(x, scaler, mean_scale, out):
    """RMSNorm over x (rows of seq tiles): out = x * rsqrt(mean(x^2) + eps).

    Builds the dataflow streams between the reader, compute, and writer
    kernels, then wires them up. Reductions ride on `scaler`-shaped tiles.
    """
    grid_cols, _ = ttl.grid_size(dims=2)
    seq_tiles = x.shape[0] // TILE
    tiles_per_core = -(-seq_tiles // grid_cols)  # ceil-divide rows across cores
    # Streams carrying x tiles and intermediate results (double-buffered),
    # plus single-buffered one-shot constants.
    stream_x = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    stream_sc = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    stream_ms = ttl.make_dataflow_buffer_like(mean_scale, shape=(1, 1), buffer_factor=1)
    stream_sq = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    stream_red = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    stream_acc = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    stream_bcast = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    stream_rsq = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    stream_out = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)
    # reader -> compute -> writer
    dm_read(stream_x, stream_sc, stream_ms)
    compute(stream_x, stream_sc, stream_ms, stream_sq, stream_red, stream_acc,
            stream_bcast, stream_rsq, stream_out)
    dm_write(stream_out)
@ttl.compute()
def compute(q_dfb, k_dfb, v_dfb, sc_dfb, scaler_dfb, ninf_dfb, zero_dfb, zh_dfb,
            mask_dfb, kt_dfb, qk_dfb, scaled_dfb, cm_dfb, m_new_dfb, alpha_dfb,
            alpha_bc_dfb, exp_dfb, cs_dfb, co_dfb, m_dfb, l_dfb, o_dfb,
            l_bc_dfb, out_dfb, m_out_dfb, l_out_dfb):
    """Causal attention compute kernel with an online (streaming) softmax.

    One head per core (grid=(n_head, 1)). For each query row tile it keeps
    running softmax statistics m (row max), l (row sum) and the unnormalized
    output o, updating them per kv tile, then emits out = o / l plus the
    final m and l (for a later backward pass, presumably — confirm).

    NOTE(review): `seq_tiles` comes from an enclosing scope not visible here.
    """
    h, _ = ttl.node(dims=2)
    # One-off constants pushed once by dm_read.
    sc_blk = sc_dfb.pop_back_val()          # attention scale tile (e.g. 1/sqrt(d) — confirm)
    scaler_blk = scaler_dfb.pop_back_val()  # reduction scaler
    ninf_blk = ninf_dfb.pop_back_val()      # -inf tile: initial running max
    zero_blk = zero_dfb.pop_back_val()      # zero tile: initial running sum
    zh_blk = zh_dfb.pop_back_val()          # zero head-row: initial output accumulator
    for q_row in range(seq_tiles):
        q_blk = q_dfb.pop_back_val()
        # Init running state: m=-inf, l=0, o=0
        m_dfb.push(ninf_blk)
        l_dfb.push(zero_blk)
        o_dfb.push(zh_blk)
        # Causal: only kv columns <= q_row contribute.
        for kv_col in range(q_row + 1):
            kc = k_dfb.pop_back_val()
            kt_dfb.push(ttl.transpose(kc))
            ktv = kt_dfb.pop_back_val()
            qk_dfb.push(q_blk @ ktv)        # S = Q @ K^T for this tile pair
            qkv = qk_dfb.pop_back_val()
            mv = mask_dfb.pop_back_val()    # causal mask on the diagonal tile, zeros elsewhere
            scaled_dfb.push(sc_blk * qkv + mv)
            # Online softmax
            sd = scaled_dfb.pop_back_val()
            cm_dfb.push(ttl.math.reduce_max(sd, scaler_blk, dims=[1]))
            m_old = m_dfb.pop_back_val()
            cm = cm_dfb.pop_back_val()
            m_new_dfb.push(ttl.math.max(m_old, cm))   # m_new = max(m_old, rowmax(S))
            mn = m_new_dfb.pop_back_val()
            alpha_dfb.push(ttl.math.exp(m_old - mn))  # alpha rescales the stale stats
            exp_dfb.push(ttl.math.exp(sd - mn))       # P = exp(S - m_new)
            m_dfb.push(mn)
            exp_blk = exp_dfb.pop_back_val()
            cs_dfb.push(ttl.math.reduce_sum(exp_blk, scaler_blk, dims=[1]))
            alpha_blk = alpha_dfb.pop_back_val()
            l_old = l_dfb.pop_back_val()
            cs = cs_dfb.pop_back_val()
            l_dfb.push(alpha_blk * l_old + cs)        # l = alpha*l + rowsum(P)
            alpha_bc_dfb.push(ttl.math.broadcast(alpha_blk, alpha_bc_dfb, dims=[1]))
            abc = alpha_bc_dfb.pop_back_val()
            o_old = o_dfb.pop_back_val()
            co_dfb.push(abc * o_old)                  # rescale running output by alpha
            co = co_dfb.pop_back_val()
            vc = v_dfb.pop_back_val()
            o_dfb.push(co + exp_blk @ vc)             # o = alpha*o + P @ V
        # Save m and l, normalize o
        m_final = m_dfb.pop_back_val()
        m_out_dfb.push(m_final)
        l_final = l_dfb.pop_back_val()
        l_out_dfb.push(l_final)
        l_bc_dfb.push(ttl.math.broadcast(l_final, l_bc_dfb, dims=[1]))
        o_final = o_dfb.pop_back_val()
        lbc = l_bc_dfb.pop_back_val()
        out_dfb.push(o_final / lbc)                   # final softmax normalization
@ttl.datamovement()
def dm_read(q_dfb, k_dfb, v_dfb, sc_dfb, scaler_dfb, ninf_dfb, zero_dfb, zh_dfb, mask_dfb):
    """Feed the attention compute kernel.

    Push order must mirror the compute kernel's pop order exactly:
    constants once, then per (q_row, kv_col) one K row, one V row, and one
    mask tile (the causal mask on the diagonal, zeros off-diagonal).
    """
    h, _ = ttl.node(dims=2)
    kv_base = h * seq_tiles  # heads appear stacked along the tile-row axis — confirm layout
    # Load constants
    tx = ttl.copy(scale_tile[0, 0], sc_dfb.reserve()); tx.wait()
    tx = ttl.copy(scaler[0, 0], scaler_dfb.reserve()); tx.wait()
    tx = ttl.copy(neg_inf_tile[0, 0], ninf_dfb.reserve()); tx.wait()
    tx = ttl.copy(zero_tile[0, 0], zero_dfb.reserve()); tx.wait()
    tx = ttl.copy(zero_head[0, 0:HEAD_TILES], zh_dfb.reserve()); tx.wait()
    for q_row in range(seq_tiles):
        # One full Q row tile-strip per outer iteration.
        tx = ttl.copy(Q[kv_base + q_row:kv_base + q_row + 1, 0:HEAD_TILES], q_dfb.reserve())
        tx.wait()
        for kv_col in range(q_row + 1):
            tx = ttl.copy(K[kv_base + kv_col:kv_base + kv_col + 1, 0:HEAD_TILES], k_dfb.reserve())
            tx.wait()
            tx = ttl.copy(V[kv_base + kv_col:kv_base + kv_col + 1, 0:HEAD_TILES], v_dfb.reserve())
            tx.wait()
            if kv_col == q_row:
                # Diagonal tile: apply the (upper-triangular -inf) causal mask.
                tx = ttl.copy(causal_mask[0, 0], mask_dfb.reserve()); tx.wait()
            else:
                # Strictly-past tile: fully visible, mask contributes zero.
                tx = ttl.copy(zero_tile[0, 0], mask_dfb.reserve()); tx.wait()
@ttl.datamovement()
def dm_write(out_dfb, m_out_dfb, l_out_dfb):
    """Per query row: write the softmax stats (m, l) then the normalized output strip."""
    head, _ = ttl.node(dims=2)
    base = head * seq_tiles
    for row in range(seq_tiles):
        ttl.copy(m_out_dfb.pop_back_val(), m_out[base + row, 0]).wait()
        ttl.copy(l_out_dfb.pop_back_val(), l_out[base + row, 0]).wait()
        xfer = ttl.copy(out_dfb.pop_back_val(),
                        out[base + row:base + row + 1, 0:HEAD_TILES])
        xfer.wait()
@ttl.operation(grid=(n_head, 1))
def training_attention(Q, K, V, scale_tile, scaler, neg_inf_tile,
                       zero_tile, zero_head, causal_mask,
                       out, m_out, l_out):
    """Causal attention forward pass, one head per core.

    Emits the attention output plus the per-row softmax max (m_out) and sum
    (l_out). Constants are single-buffered (pushed once); data/intermediate
    streams are double-buffered. The 25-stream positional wiring below must
    match the compute kernel's parameter order exactly.
    """
    # Inputs streamed per row-strip.
    q_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, HEAD_TILES), buffer_factor=2)
    k_dfb = ttl.make_dataflow_buffer_like(K, shape=(1, HEAD_TILES), buffer_factor=2)
    v_dfb = ttl.make_dataflow_buffer_like(V, shape=(1, HEAD_TILES), buffer_factor=2)
    # One-shot constants.
    sc_dfb = ttl.make_dataflow_buffer_like(scale_tile, shape=(1, 1), buffer_factor=1)
    scaler_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    ninf_dfb = ttl.make_dataflow_buffer_like(neg_inf_tile, shape=(1, 1), buffer_factor=1)
    zero_dfb = ttl.make_dataflow_buffer_like(zero_tile, shape=(1, 1), buffer_factor=1)
    zh_dfb = ttl.make_dataflow_buffer_like(zero_head, shape=(1, HEAD_TILES), buffer_factor=1)
    # Per-tile intermediates of the online-softmax pipeline.
    mask_dfb = ttl.make_dataflow_buffer_like(causal_mask, shape=(1, 1), buffer_factor=2)
    kt_dfb = ttl.make_dataflow_buffer_like(K, shape=(HEAD_TILES, 1), buffer_factor=2)
    qk_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    scaled_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    cm_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    m_new_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    alpha_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    alpha_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    exp_dfb = ttl.make_dataflow_buffer_like(Q, shape=(1, 1), buffer_factor=2)
    cs_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    co_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    # Running state (m, l, o) lives in these streams between kv iterations.
    m_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    l_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    o_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    l_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    # Final outputs back to DRAM.
    out_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    m_out_dfb = ttl.make_dataflow_buffer_like(m_out, shape=(1, 1), buffer_factor=2)
    l_out_dfb = ttl.make_dataflow_buffer_like(l_out, shape=(1, 1), buffer_factor=2)
    dm_read(q_dfb, k_dfb, v_dfb, sc_dfb, scaler_dfb, ninf_dfb, zero_dfb, zh_dfb, mask_dfb)
    compute(q_dfb, k_dfb, v_dfb, sc_dfb, scaler_dfb, ninf_dfb, zero_dfb, zh_dfb,
            mask_dfb, kt_dfb, qk_dfb, scaled_dfb, cm_dfb, m_new_dfb, alpha_dfb,
            alpha_bc_dfb, exp_dfb, cs_dfb, co_dfb, m_dfb, l_dfb, o_dfb,
            l_bc_dfb, out_dfb, m_out_dfb, l_out_dfb)
    dm_write(out_dfb, m_out_dfb, l_out_dfb)
# ========================================================================
# ============================== FUSED ===================================
# ========================================================================
@ttl.operation(grid=(n_head, 1))
def fused_rmsnorm_attention(x, scaler, mean_scale,
                            K, V, scale_tile, attn_scaler, neg_inf_tile,
                            zero_tile, zero_head, causal_mask,
                            out, m_out, l_out):
    """Fused RMSNorm + causal attention: rmsnorm's output stream feeds
    attention's Q input directly, skipping the DRAM round trip.

    NOTE(review): `rmsnorm_dm_read`, `attn_dm_read`, `rmsnorm_compute`,
    `attn_compute` and `attn_dm_write` are not defined anywhere in this
    view — confirm they exist (renamed versions of the kernels above?).
    NOTE(review): rmsnorm_compute pushes (1, 1) tiles into its out stream,
    but q_dfb is declared (1, HEAD_TILES) and attn_compute pops one full
    row-strip per q_row — confirm the producer/consumer granularity agrees.
    NOTE(review): rmsnorm distributes seq tiles across grid columns while
    attention assumes one head per core; with grid=(n_head, 1) each core
    would normalize a different row range than the head it attends — verify
    the intended work decomposition.
    """
    grid_cols, _ = ttl.grid_size(dims=2)
    seq_tiles = x.shape[0] // TILE
    tiles_per_core = -(-seq_tiles // grid_cols)  # ceil-divide
    # === RMSNorm DFBs ===
    x_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    sc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=1)
    ms_dfb = ttl.make_dataflow_buffer_like(mean_scale, shape=(1, 1), buffer_factor=1)
    sq_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    red_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    acc_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    bcast_dfb = ttl.make_dataflow_buffer_like(x, shape=(1, 1), buffer_factor=2)
    rsq_dfb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
    # === SHARED: rmsnorm output IS attention's Q input ===
    q_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    # rmsnorm writes into q_dfb, attention reads from q_dfb — no DRAM round trip
    # === Attention DFBs ===
    k_dfb = ttl.make_dataflow_buffer_like(K, shape=(1, HEAD_TILES), buffer_factor=2)
    v_dfb = ttl.make_dataflow_buffer_like(V, shape=(1, HEAD_TILES), buffer_factor=2)
    attn_sc_dfb = ttl.make_dataflow_buffer_like(scale_tile, shape=(1, 1), buffer_factor=1)
    attn_scaler_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=1)
    ninf_dfb = ttl.make_dataflow_buffer_like(neg_inf_tile, shape=(1, 1), buffer_factor=1)
    zero_dfb = ttl.make_dataflow_buffer_like(zero_tile, shape=(1, 1), buffer_factor=1)
    zh_dfb = ttl.make_dataflow_buffer_like(zero_head, shape=(1, HEAD_TILES), buffer_factor=1)
    mask_dfb = ttl.make_dataflow_buffer_like(causal_mask, shape=(1, 1), buffer_factor=2)
    kt_dfb = ttl.make_dataflow_buffer_like(K, shape=(HEAD_TILES, 1), buffer_factor=2)
    qk_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)
    scaled_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)
    cm_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    m_new_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    alpha_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    alpha_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    exp_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, 1), buffer_factor=2)
    cs_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    co_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    m_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    l_dfb = ttl.make_dataflow_buffer_like(attn_scaler, shape=(1, 1), buffer_factor=2)
    o_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    l_bc_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    out_dfb = ttl.make_dataflow_buffer_like(out, shape=(1, HEAD_TILES), buffer_factor=2)
    m_out_dfb = ttl.make_dataflow_buffer_like(m_out, shape=(1, 1), buffer_factor=2)
    l_out_dfb = ttl.make_dataflow_buffer_like(l_out, shape=(1, 1), buffer_factor=2)
    # === Wire it up: rmsnorm → attention, no DRAM in between ===
    # Datamovement: rmsnorm reads, attention reads K/V/constants
    rmsnorm_dm_read(x_dfb, sc_dfb, ms_dfb)
    attn_dm_read(q_dfb, k_dfb, v_dfb, attn_sc_dfb, attn_scaler_dfb,
                 ninf_dfb, zero_dfb, zh_dfb, mask_dfb)
    # Compute: rmsnorm produces into q_dfb, attention consumes from q_dfb
    rmsnorm_compute(x_dfb, sc_dfb, ms_dfb, sq_dfb, red_dfb, acc_dfb,
                    bcast_dfb, rsq_dfb, q_dfb)  # note: q_dfb is rmsnorm's "out_dfb"
    attn_compute(q_dfb, k_dfb, v_dfb, attn_sc_dfb, attn_scaler_dfb,
                 ninf_dfb, zero_dfb, zh_dfb, mask_dfb,
                 kt_dfb, qk_dfb, scaled_dfb, cm_dfb, m_new_dfb, alpha_dfb,
                 alpha_bc_dfb, exp_dfb, cs_dfb, co_dfb,
                 m_dfb, l_dfb, o_dfb, l_bc_dfb, out_dfb, m_out_dfb, l_out_dfb)
    # Only attention writes to DRAM
    attn_dm_write(out_dfb, m_out_dfb, l_out_dfb)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment