# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
"""
Sweep benchmark: mcast matmul, balanced matmul, balanced matmul+relu vs ttnn.

    mcast:    A+B mcast, both on dm_read; grid sized by _even_split (works at any size)
    balanced: A on dm_read, B on dm_write (two-NoC), 1 block/core
    relu:     balanced matmul with fused relu on the last K iteration

Small shapes also test L1 interleaved inputs.
"""
import ttnn
import ttl
import sys
import time

sys.path.insert(0, "/tmp")

TILE = 32
BLOCK_M = 8
BLOCK_N = 8
BLOCK_K = 8
BLOCK_SIZE = BLOCK_M * TILE  # 256

DRAM = ttnn.DRAM_MEMORY_CONFIG
L1 = ttnn.L1_MEMORY_CONFIG

# Toggle accumulation precision for both TTL and TTNN kernels.
#   True  = f32 dest acc (HiFi4) -- higher accuracy, slower
#   False = bf16 dest acc (HiFi2) -- lower accuracy, faster
FP32_ACC = True

TTNN_COMPUTE_CONFIG = ttnn.WormholeComputeKernelConfig(
    math_fidelity=ttnn.MathFidelity.HiFi4 if FP32_ACC else ttnn.MathFidelity.HiFi2,
    fp32_dest_acc_en=FP32_ACC,
    packer_l1_acc=True,
)

MAX_GRID_N = 13  # Wormhole max grid columns
MAX_GRID_M = 10  # Wormhole max grid rows
def _even_split(n_blocks, max_grid):
    """Pick blocks_per_node that divides n_blocks evenly.
    Avoids pipe deadlock when remainder cores skip iterations that src cores don't."""
    bpn = -(-n_blocks // max_grid)
    while n_blocks % bpn != 0:
        bpn += 1
    return bpn, n_blocks // bpn
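# `-(-n_blocks // max_grid)` is ceiling division; bpn then grows until it divides
# n_blocks exactly. A few values worked out from the logic above, for illustration:
#   _even_split(10, 13) -> (1, 10)   # fits the grid, 1 block per node
#   _even_split(12, 10) -> (2, 6)    # 12 blocks on 6 nodes, 2 blocks each
#   _even_split(13, 10) -> (13, 1)   # prime block count collapses to a single node
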
# ---------------------------------------------------------------------------
# Kernel: mcast matmul (both A+B on dm_read)
# ---------------------------------------------------------------------------
def make_mcast_kernel(M_DIM, K_DIM, N_DIM):
    M_BLOCKS = M_DIM // BLOCK_SIZE
    N_BLOCKS = N_DIM // BLOCK_SIZE
    K_BLOCKS = K_DIM // BLOCK_SIZE
    M_BPN, NUM_ROWS = _even_split(M_BLOCKS, MAX_GRID_M)
    N_BPN, NUM_COLS = _even_split(N_BLOCKS, MAX_GRID_N)

    @ttl.operation(grid=(NUM_COLS, NUM_ROWS), fp32_dest_acc_en=FP32_ACC)
    def mcast_matmul(a, w, out):
        m_blocks_per_node = M_BPN
        n_blocks_per_node = N_BPN
        a_pipes = [
            ttl.Pipe(src=(0, row), dst=(slice(0, NUM_COLS), row))
            for row in range(NUM_ROWS)
        ]
        mcast_a_net = ttl.PipeNet(a_pipes)
        b_pipes = [
            ttl.Pipe(src=(col, 0), dst=(col, slice(0, NUM_ROWS)))
            for col in range(NUM_COLS)
        ]
        mcast_b_net = ttl.PipeNet(b_pipes)
        a_cb = ttl.make_dataflow_buffer_like(a, shape=(BLOCK_M, BLOCK_K), block_count=2)
        b_cb = ttl.make_dataflow_buffer_like(w, shape=(BLOCK_K, BLOCK_N), block_count=2)
        out_cb = ttl.make_dataflow_buffer_like(out, shape=(BLOCK_M, BLOCK_N), block_count=2)

        @ttl.compute()
        def compute():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    out_blk = out_cb.reserve()
                    for _ in range(K_BLOCKS):
                        a_blk = a_cb.wait()
                        b_blk = b_cb.wait()
                        out_blk += a_blk @ b_blk
                        a_blk.pop()
                        b_blk.pop()
                    out_blk.push()

        @ttl.datamovement()
        def dm_read():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    nc = nb * BLOCK_N
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        with a_cb.reserve() as a_blk:
                            mcast_a_net.if_src(lambda pipe: (
                                ttl.copy(a[mr:mr + BLOCK_M, kc:kc + BLOCK_K], a_blk).wait(),
                                ttl.copy(a_blk, pipe).wait(),
                            ))
                            mcast_a_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, a_blk).wait(),
                            ))
                        with b_cb.reserve() as b_blk:
                            mcast_b_net.if_src(lambda pipe: (
                                ttl.copy(w[kc:kc + BLOCK_K, nc:nc + BLOCK_N], b_blk).wait(),
                                ttl.copy(b_blk, pipe).wait(),
                            ))
                            mcast_b_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, b_blk).wait(),
                            ))

        @ttl.datamovement()
        def dm_write():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    nc = nb * BLOCK_N
                    with out_cb.wait() as out_blk:
                        ttl.copy(out_blk, out[mr:mr + BLOCK_M, nc:nc + BLOCK_N]).wait()

    return mcast_matmul
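# The factory returns a device op that is called like a plain function on ttnn
# tensors, e.g. (example shapes only; see run_matmul_shape and __main__ below):
#   mm = make_mcast_kernel(2048, 2048, 2048)
#   mm(a, w, out)   # a: (M, K), w: (K, N), out: (M, N), bf16 TILE_LAYOUT
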
# ---------------------------------------------------------------------------
# Kernel: balanced matmul (A on dm_read, B on dm_write)
# ---------------------------------------------------------------------------
def make_balanced_kernel(M_DIM, K_DIM, N_DIM):
    M_BLOCKS = M_DIM // BLOCK_SIZE
    N_BLOCKS = N_DIM // BLOCK_SIZE
    K_BLOCKS = K_DIM // BLOCK_SIZE
    M_BPN, NUM_ROWS = _even_split(M_BLOCKS, MAX_GRID_M)
    N_BPN, NUM_COLS = _even_split(N_BLOCKS, MAX_GRID_N)

    @ttl.operation(grid=(NUM_COLS, NUM_ROWS), fp32_dest_acc_en=FP32_ACC)
    def balanced_matmul(a, w, out):
        m_blocks_per_node = M_BPN
        n_blocks_per_node = N_BPN
        a_pipes = [
            ttl.Pipe(src=(0, row), dst=(slice(0, NUM_COLS), row))
            for row in range(NUM_ROWS)
        ]
        mcast_a_net = ttl.PipeNet(a_pipes)
        b_pipes = [
            ttl.Pipe(src=(col, 0), dst=(col, slice(0, NUM_ROWS)))
            for col in range(NUM_COLS)
        ]
        mcast_b_net = ttl.PipeNet(b_pipes)
        a_cb = ttl.make_dataflow_buffer_like(a, shape=(BLOCK_M, BLOCK_K), block_count=2)
        b_cb = ttl.make_dataflow_buffer_like(w, shape=(BLOCK_K, BLOCK_N), block_count=2)
        out_cb = ttl.make_dataflow_buffer_like(out, shape=(BLOCK_M, BLOCK_N), block_count=2)

        @ttl.compute()
        def compute():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    out_blk = out_cb.reserve()
                    for _ in range(K_BLOCKS):
                        a_blk = a_cb.wait()
                        b_blk = b_cb.wait()
                        out_blk += a_blk @ b_blk
                        a_blk.pop()
                        b_blk.pop()
                    out_blk.push()

        @ttl.datamovement()
        def dm_read():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        with a_cb.reserve() as a_blk:
                            mcast_a_net.if_src(lambda pipe: (
                                ttl.copy(a[mr:mr + BLOCK_M, kc:kc + BLOCK_K], a_blk).wait(),
                                ttl.copy(a_blk, pipe).wait(),
                            ))
                            mcast_a_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, a_blk).wait(),
                            ))

        @ttl.datamovement()
        def dm_write():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    nc = nb * BLOCK_N
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        with b_cb.reserve() as b_blk:
                            mcast_b_net.if_src(lambda pipe: (
                                ttl.copy(w[kc:kc + BLOCK_K, nc:nc + BLOCK_N], b_blk).wait(),
                                ttl.copy(b_blk, pipe).wait(),
                            ))
                            mcast_b_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, b_blk).wait(),
                            ))
                    with out_cb.wait() as out_blk:
                        ttl.copy(out_blk, out[mr:mr + BLOCK_M, nc:nc + BLOCK_N]).wait()

    return balanced_matmul
# ---------------------------------------------------------------------------
# Kernel: balanced matmul + relu (fused activation on last K iteration)
# ---------------------------------------------------------------------------
def make_balanced_relu_kernel(M_DIM, K_DIM, N_DIM):
    M_BLOCKS = M_DIM // BLOCK_SIZE
    N_BLOCKS = N_DIM // BLOCK_SIZE
    K_BLOCKS = K_DIM // BLOCK_SIZE
    M_BPN, NUM_ROWS = _even_split(M_BLOCKS, MAX_GRID_M)
    N_BPN, NUM_COLS = _even_split(N_BLOCKS, MAX_GRID_N)

    @ttl.operation(grid=(NUM_COLS, NUM_ROWS), fp32_dest_acc_en=FP32_ACC)
    def balanced_matmul_relu(a, w, out):
        m_blocks_per_node = M_BPN
        n_blocks_per_node = N_BPN
        a_pipes = [
            ttl.Pipe(src=(0, row), dst=(slice(0, NUM_COLS), row))
            for row in range(NUM_ROWS)
        ]
        mcast_a_net = ttl.PipeNet(a_pipes)
        b_pipes = [
            ttl.Pipe(src=(col, 0), dst=(col, slice(0, NUM_ROWS)))
            for col in range(NUM_COLS)
        ]
        mcast_b_net = ttl.PipeNet(b_pipes)
        a_cb = ttl.make_dataflow_buffer_like(a, shape=(BLOCK_M, BLOCK_K), block_count=2)
        b_cb = ttl.make_dataflow_buffer_like(w, shape=(BLOCK_K, BLOCK_N), block_count=2)
        acc_cb = ttl.make_dataflow_buffer_like(out, shape=(BLOCK_M, BLOCK_N), block_count=2)
        out_cb = ttl.make_dataflow_buffer_like(out, shape=(BLOCK_M, BLOCK_N), block_count=1)

        @ttl.compute()
        def compute():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    with acc_cb.reserve() as init:
                        init.store(ttl.math.fill(init, 0))
                    for kb in range(K_BLOCKS):
                        with (
                            a_cb.wait() as a_blk,
                            b_cb.wait() as b_blk,
                            acc_cb.wait() as last,
                            acc_cb.reserve() as next,
                        ):
                            if kb < K_BLOCKS - 1:
                                next.store(last + a_blk @ b_blk)
                            else:
                                next.store(ttl.math.relu(last + a_blk @ b_blk))
                    with acc_cb.wait() as result, out_cb.reserve() as o:
                        o.store(result)

        @ttl.datamovement()
        def dm_read():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        with a_cb.reserve() as a_blk:
                            mcast_a_net.if_src(lambda pipe: (
                                ttl.copy(a[mr:mr + BLOCK_M, kc:kc + BLOCK_K], a_blk).wait(),
                                ttl.copy(a_blk, pipe).wait(),
                            ))
                            mcast_a_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, a_blk).wait(),
                            ))

        @ttl.datamovement()
        def dm_write():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    nc = nb * BLOCK_N
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        with b_cb.reserve() as b_blk:
                            mcast_b_net.if_src(lambda pipe: (
                                ttl.copy(w[kc:kc + BLOCK_K, nc:nc + BLOCK_N], b_blk).wait(),
                                ttl.copy(b_blk, pipe).wait(),
                            ))
                            mcast_b_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, b_blk).wait(),
                            ))
                    with out_cb.wait() as out_blk:
                        ttl.copy(out_blk, out[mr:mr + BLOCK_M, nc:nc + BLOCK_N]).wait()

    return balanced_matmul_relu
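# Unlike mcast/balanced above, which accumulate in place via "out_blk += a_blk @ b_blk",
# the relu kernel ping-pongs partial sums through acc_cb (block_count=2): each K step
# reads the previous partial and writes the next, and relu is fused into the final step
# before the result is handed to the single-buffered out_cb for dm_write.
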
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def bench(fn, *args, warmup=2, iters=5):
    for _ in range(warmup):
        fn(*args)
    ttnn.synchronize_device(args[0].device())
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)
        ttnn.synchronize_device(args[0].device())
        times.append((time.perf_counter() - t0) * 1000)
    return min(times), sum(times) / len(times)


def to_device(t, device, memory_config):
    return ttnn.from_torch(t.contiguous(), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT,
                           device=device, memory_config=memory_config)


def make_tensors(M, K, N, device, memory_config):
    import torch
    a = to_device(torch.randn(M, K, dtype=torch.bfloat16) * 0.02, device, memory_config)
    w = to_device(torch.randn(K, N, dtype=torch.bfloat16) * 0.02, device, memory_config)
    out = to_device(torch.zeros(M, N, dtype=torch.bfloat16), device, memory_config)
    return a, w, out
# ---------------------------------------------------------------------------
# Shapes
# ---------------------------------------------------------------------------
SHAPES_SMALL = [
    (1024, 1024, 1024, "1k^3"),
    (1024, 2048, 1024, "1k x 2k x 1k"),
    (2048, 2048, 2048, "2k^3"),
    (2048, 4096, 2048, "2k x 4k x 2k"),
    (2560, 2048, 3072, "2.5k x 2k x 3k"),
]

SHAPES_LARGE = [
    (2048, 8192, 2048, "2k x 8k x 2k (long K)"),
    (2560, 4096, 3072, "2.5k x 4k x 3k"),
    (2560, 8192, 3072, "2.5k x 8k x 3k (120 cores)"),
    (2560, 8192, 3328, "2.5k x 8k x 3.3k (130 cores)"),
    (1024, 16384, 2560, "1k x 16k x 2.5k (tall K)"),
    (4096, 4096, 4096, "4k^3"),
    (4096, 8192, 4096, "4k x 8k x 4k"),
    (8192, 8192, 8192, "8k^3"),
    (10240, 8192, 13312, "10k x 8k x 13k (130 cores, 4x4)"),
]
def dealloc(*tensors):
    for t in tensors:
        ttnn.deallocate(t)


def bench_and_cleanup(fn, *tensors, device):
    best, avg = bench(fn, *tensors)
    dealloc(*tensors)
    ttnn.synchronize_device(device)
    time.sleep(0.01)
    return best, avg


def run_matmul_shape(M, K, N, label, device, mem_config=DRAM):
    mc_label = "L1" if mem_config is not DRAM else ""

    mcast_fn = make_mcast_kernel(M, K, N)
    a, w, o = make_tensors(M, K, N, device, mem_config)
    mc_best, mc_avg = bench_and_cleanup(mcast_fn, a, w, o, device=device)

    bal_fn = make_balanced_kernel(M, K, N)
    a2, w2, o2 = make_tensors(M, K, N, device, mem_config)
    bal_best, bal_avg = bench_and_cleanup(bal_fn, a2, w2, o2, device=device)

    def ttnn_mm(a, w):
        return ttnn.matmul(a, w, compute_kernel_config=TTNN_COMPUTE_CONFIG)

    a3, w3, _ = make_tensors(M, K, N, device, mem_config)
    ttnn_best, ttnn_avg = bench_and_cleanup(ttnn_mm, a3, w3, device=device)

    mc_r = f"{mc_best / ttnn_best:.2f}x"
    bal_r = f"{bal_best / ttnn_best:.2f}x"
    lbl = f"{label} {mc_label}" if mc_label else label
    print(
        f"{lbl:<32} "
        f"{mc_best:>9.2f}ms {mc_avg:>8.2f}ms "
        f"{bal_best:>9.2f}ms {bal_avg:>8.2f}ms "
        f"{ttnn_best:>9.2f}ms {ttnn_avg:>8.2f}ms "
        f"{mc_r:>8} {bal_r:>9}"
    )
    return {"label": lbl, "flops": 2*M*K*N, "mcast": mc_best, "balanced": bal_best, "ttnn": ttnn_best}


def run_relu_shape(M, K, N, label, device, mem_config=DRAM):
    mc_label = "L1" if mem_config is not DRAM else ""

    relu_fn = make_balanced_relu_kernel(M, K, N)
    a, w, o = make_tensors(M, K, N, device, mem_config)
    ttl_best, ttl_avg = bench_and_cleanup(relu_fn, a, w, o, device=device)

    def linear_relu(a, w):
        return ttnn.linear(a, w, activation="relu", compute_kernel_config=TTNN_COMPUTE_CONFIG)

    a2, w2, _ = make_tensors(M, K, N, device, mem_config)
    lin_best, lin_avg = bench_and_cleanup(linear_relu, a2, w2, device=device)

    def mm_relu(a, w):
        return ttnn.relu(ttnn.matmul(a, w, compute_kernel_config=TTNN_COMPUTE_CONFIG))

    a3, w3, _ = make_tensors(M, K, N, device, mem_config)
    sep_best, sep_avg = bench_and_cleanup(mm_relu, a3, w3, device=device)

    r_lin = f"{ttl_best / lin_best:.2f}x"
    r_sep = f"{ttl_best / sep_best:.2f}x"
    lbl = f"{label} {mc_label}" if mc_label else label
    print(
        f"{lbl:<32} "
        f"{ttl_best:>9.2f}ms {ttl_avg:>8.2f}ms "
        f"{lin_best:>10.2f}ms {lin_avg:>8.2f}ms "
        f"{sep_best:>9.2f}ms {sep_avg:>8.2f}ms "
        f"{r_lin:>8} {r_sep:>8}"
    )
    return {"label": lbl, "flops": 2*M*K*N, "ttl_relu": ttl_best, "ttnn_linear": lin_best, "mm_relu": sep_best}
def save_plot(matmul_data, relu_data, path="/tmp/bench_matmul_sweep.png"):
    """Generate performance plots with shape labels on x-axis."""
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError:
        print("matplotlib not available, skipping plot")
        return

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(16, 18))

    # --- Top: Matmul sweep ---
    # Group by shape label (without L1 suffix), show DRAM and L1 variants
    shape_order = []
    series = {
        "ttnn DRAM": {}, "ttnn L1": {},
        "ttl balanced DRAM": {}, "ttl balanced L1": {},
        "ttl mcast DRAM": {}, "ttl mcast L1": {},
    }
    for d in matmul_data:
        lbl = d["label"].strip()
        is_l1 = lbl.endswith("L1")
        base = lbl.replace(" L1", "").strip() if is_l1 else lbl
        suffix = " L1" if is_l1 else " DRAM"
        if base not in shape_order:
            shape_order.append(base)
        if d.get("ttnn") is not None:
            series[f"ttnn{suffix}"][base] = d["ttnn"]
        if d.get("balanced") is not None:
            series[f"ttl balanced{suffix}"][base] = d["balanced"]
        if d.get("mcast") is not None:
            series[f"ttl mcast{suffix}"][base] = d["mcast"]

    x = np.arange(len(shape_order))
    colors = {
        "ttnn DRAM": "#1f77b4", "ttnn L1": "#aec7e8",
        "ttl balanced DRAM": "#2ca02c", "ttl balanced L1": "#98df8a",
        "ttl mcast DRAM": "#ff7f0e", "ttl mcast L1": "#ffbb78",
    }
    markers = {
        "ttnn DRAM": "s", "ttnn L1": "s",
        "ttl balanced DRAM": "o", "ttl balanced L1": "o",
        "ttl mcast DRAM": "^", "ttl mcast L1": "^",
    }
    linestyles = {k: "-" if "DRAM" in k else "--" for k in series}
    for name, data in series.items():
        if not data:
            continue
        vals = [data.get(s) for s in shape_order]
        xs = [i for i, v in enumerate(vals) if v is not None]
        ys = [v for v in vals if v is not None]
        if ys:
            ax1.plot(xs, ys, marker=markers[name], color=colors[name],
                     linestyle=linestyles[name], label=name, linewidth=1.5,
                     markersize=5, alpha=0.85)
    ax1.set_xticks(x)
    ax1.set_xticklabels(shape_order, rotation=45, ha="right", fontsize=8)
    ax1.set_ylabel("Time (ms)")
    ax1.set_title("Matmul: TTL vs TTNN (lower is better)")
    ax1.legend(fontsize=7, ncol=3, loc="upper left")
    ax1.set_yscale("log")
    ax1.grid(True, alpha=0.3, axis="y")

    # --- Middle: Fused relu ---
    relu_shapes = []
    relu_series = {
        "ttl relu DRAM": {}, "ttl relu L1": {},
        "ttnn.linear DRAM": {}, "ttnn.linear L1": {},
        "ttnn.matmul+ttnn.relu DRAM": {}, "ttnn.matmul+ttnn.relu L1": {},
    }
    for d in relu_data:
        lbl = d["label"].strip()
        is_l1 = lbl.endswith("L1")
        base = lbl.replace(" L1", "").strip() if is_l1 else lbl
        suffix = " L1" if is_l1 else " DRAM"
        if base not in relu_shapes:
            relu_shapes.append(base)
        relu_series[f"ttl relu{suffix}"][base] = d["ttl_relu"]
        relu_series[f"ttnn.linear{suffix}"][base] = d["ttnn_linear"]
        relu_series[f"ttnn.matmul+ttnn.relu{suffix}"][base] = d["mm_relu"]

    if relu_shapes:
        x2 = np.arange(len(relu_shapes))
        relu_colors = {
            "ttl relu DRAM": "#2ca02c", "ttl relu L1": "#98df8a",
            "ttnn.linear DRAM": "#1f77b4", "ttnn.linear L1": "#aec7e8",
            "ttnn.matmul+ttnn.relu DRAM": "#d62728", "ttnn.matmul+ttnn.relu L1": "#ff9896",
        }
        relu_markers = {
            "ttl relu DRAM": "o", "ttl relu L1": "o",
            "ttnn.linear DRAM": "s", "ttnn.linear L1": "s",
            "ttnn.matmul+ttnn.relu DRAM": "^", "ttnn.matmul+ttnn.relu L1": "^",
        }
        relu_ls = {k: "-" if "DRAM" in k else "--" for k in relu_series}
        for name, data in relu_series.items():
            if not data:
                continue
            vals = [data.get(s) for s in relu_shapes]
            xs = [i for i, v in enumerate(vals) if v is not None]
            ys = [v for v in vals if v is not None]
            if ys:
                ax2.plot(xs, ys, marker=relu_markers[name], color=relu_colors[name],
                         linestyle=relu_ls[name], label=name, linewidth=1.5,
                         markersize=5, alpha=0.85)
        ax2.set_xticks(x2)
        ax2.set_xticklabels(relu_shapes, rotation=45, ha="right", fontsize=8)
        ax2.set_ylabel("Time (ms)")
        ax2.set_title("Fused Matmul+ReLU: TTL vs TTNN (lower is better)")
        ax2.legend(fontsize=7, ncol=3, loc="upper left")
        ax2.set_yscale("log")
        ax2.grid(True, alpha=0.3, axis="y")

    # --- Bottom: Ratio plot (TTL / TTNN baseline), all shapes merged ---
    # Matmul lines: ratio vs ttnn DRAM
    # Relu lines:   ratio vs ttnn.matmul+ttnn.relu DRAM
    all_shapes = []
    ratio_series = {}

    # Matmul ratios: all TTL variants / ttnn DRAM
    ttnn_dram = series.get("ttnn DRAM", {})
    for ttl_key in ("ttl mcast DRAM", "ttl mcast L1", "ttl balanced DRAM", "ttl balanced L1"):
        ttl_data = series.get(ttl_key, {})
        if not ttl_data:
            continue
        ratios = {}
        for shape in shape_order:
            ttl_val = ttl_data.get(shape)
            base_val = ttnn_dram.get(shape)
            if ttl_val is not None and base_val is not None:
                ratios[shape] = ttl_val / base_val
                if shape not in all_shapes:
                    all_shapes.append(shape)
        if ratios:
            ratio_series[ttl_key] = ratios

    # Relu ratios: all TTL relu variants / ttnn.matmul+ttnn.relu DRAM
    mm_relu_dram = relu_series.get("ttnn.matmul+ttnn.relu DRAM", {})
    for ttl_key in ("ttl relu DRAM", "ttl relu L1"):
        ttl_data = relu_series.get(ttl_key, {})
        if not ttl_data:
            continue
        ratios = {}
        for shape in relu_shapes:
            ttl_val = ttl_data.get(shape)
            base_val = mm_relu_dram.get(shape)
            if ttl_val is not None and base_val is not None:
                ratios[shape] = ttl_val / base_val
                if shape not in all_shapes:
                    all_shapes.append(shape)
        if ratios:
            ratio_series[ttl_key] = ratios

    if all_shapes:
        x3 = np.arange(len(all_shapes))
        ratio_colors = {
            "ttl mcast DRAM": "#ff7f0e", "ttl mcast L1": "#ffbb78",
            "ttl balanced DRAM": "#2ca02c", "ttl balanced L1": "#98df8a",
            "ttl relu DRAM": "#9467bd", "ttl relu L1": "#c5b0d5",
        }
        ratio_markers = {
            "ttl mcast DRAM": "^", "ttl mcast L1": "^",
            "ttl balanced DRAM": "o", "ttl balanced L1": "o",
            "ttl relu DRAM": "s", "ttl relu L1": "s",
        }
        ratio_ls = {k: "-" if "DRAM" in k else "--" for k in ratio_colors}
        for name, data in ratio_series.items():
            if not data:
                continue
            vals = [data.get(s) for s in all_shapes]
            xs = [i for i, v in enumerate(vals) if v is not None]
            ys = [v for v in vals if v is not None]
            if ys:
                ax3.plot(xs, ys, marker=ratio_markers.get(name, "o"),
                         color=ratio_colors.get(name, "#333333"),
                         linestyle=ratio_ls.get(name, "-"), label=name,
                         linewidth=1.5, markersize=5, alpha=0.85)
        ax3.axhline(y=1.0, color="black", linestyle=":", linewidth=1, alpha=0.7)
        ax3.set_xticks(x3)
        ax3.set_xticklabels(all_shapes, rotation=45, ha="right", fontsize=8)
        ax3.set_ylabel("Ratio (TTL / TTNN)")
        ax3.set_title("TTL / TTNN Ratio (matmul vs ttnn.matmul DRAM, relu vs ttnn.matmul+ttnn.relu DRAM) -- below 1.0 = TTL wins")
        ax3.legend(fontsize=7, ncol=4, loc="upper right")
        ax3.grid(True, alpha=0.3, axis="y")

    plt.tight_layout()
    plt.savefig(path, dpi=150)
    print(f"Plot saved to {path}")
if __name__ == "__main__":
    import torch

    torch.manual_seed(42)
    device = ttnn.open_device(device_id=0)
    matmul_data = []
    relu_data = []
    try:
        header = (
            f"{'Shape':<32} "
            f"{'mcast best':>11} {'mcast avg':>10} "
            f"{'bal best':>11} {'bal avg':>10} "
            f"{'ttnn best':>11} {'ttnn avg':>10} "
            f"{'mc/ttnn':>8} {'bal/ttnn':>9}"
        )
        sep = "-" * len(header)

        # --- Part 1: Matmul sweep ---
        acc_label = "f32" if FP32_ACC else "bf16"
        print(f"=== Mcast Matmul Sweep Benchmark (acc={acc_label}) ===")
        print(f"Block: {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, tile: {TILE}")
        print()
        print(header)
        print(sep)
        for M, K, N, label in SHAPES_SMALL:
            matmul_data.append(run_matmul_shape(M, K, N, label, device))
            matmul_data.append(run_matmul_shape(M, K, N, label, device, mem_config=L1))
        for M, K, N, label in SHAPES_LARGE:
            matmul_data.append(run_matmul_shape(M, K, N, label, device))
        print(sep)

        # --- Part 2: Fused matmul+relu (balanced only) ---
        print()
        print("=== Fused Matmul+ReLU Benchmark (balanced, 1 block/core) ===")
        relu_header = (
            f"{'Shape':<32} "
            f"{'ttl relu':>11} {'ttl avg':>10} "
            f"{'ttnn.linear':>12} {'lin avg':>10} "
            f"{'ttnn.matmul+ttnn.relu':>22} {'sep avg':>10} "
            f"{'ttl/lin':>8} {'ttl/sep':>8}"
        )
        relu_sep = "-" * len(relu_header)
        print(relu_header)
        print(relu_sep)
        for M, K, N, label in SHAPES_SMALL:
            relu_data.append(run_relu_shape(M, K, N, label, device))
            relu_data.append(run_relu_shape(M, K, N, label, device, mem_config=L1))
        for M, K, N, label in SHAPES_LARGE:
            relu_data.append(run_relu_shape(M, K, N, label, device))
        print(relu_sep)

        # --- Plot ---
        save_plot(matmul_data, relu_data)

        # --- PCC smoke test ---
        M_t, K_t, N_t = 2560, 8192, 3328
        torch.manual_seed(42)
        a_torch = torch.randn(M_t, K_t, dtype=torch.bfloat16) * 0.02
        w_torch = torch.randn(K_t, N_t, dtype=torch.bfloat16) * 0.02
        ref = a_torch.float() @ w_torch.float()

        def to_dev(t):
            return ttnn.from_torch(t.contiguous(), dtype=ttnn.bfloat16,
                                   layout=ttnn.TILE_LAYOUT, device=device,
                                   memory_config=DRAM)

        def calc_pcc(result):
            return torch.corrcoef(torch.stack([result.flatten(), ref.flatten()]))[0, 1].item()

        # TTL mcast
        out_dev = to_dev(torch.zeros(M_t, N_t, dtype=torch.bfloat16))
        mcast_fn = make_mcast_kernel(M_t, K_t, N_t)
        mcast_fn(to_dev(a_torch), to_dev(w_torch), out_dev)
        ttl_pcc = calc_pcc(ttnn.to_torch(out_dev).float())

        # TTNN
        ttnn_out = ttnn.matmul(to_dev(a_torch), to_dev(w_torch),
                               compute_kernel_config=TTNN_COMPUTE_CONFIG)
        ttnn_pcc = calc_pcc(ttnn.to_torch(ttnn_out).float())

        print()
        print(f"=== PCC Smoke Test ({M_t}x{K_t}x{N_t}, acc={acc_label}) ===")
        print(f" TTL mcast: {ttl_pcc:.6f}")
        print(f" TTNN:      {ttnn_pcc:.6f}")
        assert ttl_pcc > 0.99, f"TTL PCC too low: {ttl_pcc}"
        assert ttnn_pcc > 0.99, f"TTNN PCC too low: {ttnn_pcc}"
        print(" PASS")
        print("=== Done ===")
    finally:
        ttnn.close_device(device)