Skip to content

Instantly share code, notes, and snippets.

@zoecarver
Last active April 13, 2026 17:17
Show Gist options
  • Select an option

  • Save zoecarver/752427e20735af04144c7533131adc82 to your computer and use it in GitHub Desktop.

Select an option

Save zoecarver/752427e20735af04144c7533131adc82 to your computer and use it in GitHub Desktop.
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
"""
Sweep benchmark: mcast matmul, balanced matmul, balanced matmul+relu vs ttnn.
mcast: A+B mcast both on dm_read, grid="auto" (works at any size)
balanced: A on dm_read, B on dm_write (two-NoC), 1 block/core
relu: balanced matmul with fused relu on last K iteration
Small shapes also test L1 interleaved inputs.
"""
import ttnn
import ttl
import sys
import time
# NOTE(review): presumably makes a locally built `ttl` importable from /tmp —
# confirm; placed after the imports above, so it only affects later imports.
sys.path.insert(0, "/tmp")
# Block geometry. Buffers are declared with shape=(BLOCK_M, BLOCK_K) etc.
# below, presumably in units of TILE x TILE tiles — TODO confirm with ttl docs.
TILE = 32
BLOCK_M = 8
BLOCK_N = 8
BLOCK_K = 8
BLOCK_SIZE = BLOCK_M * TILE # 256 elements per block edge (8 tiles * 32)
# Shorthand for the two ttnn memory layouts benchmarked.
DRAM = ttnn.DRAM_MEMORY_CONFIG
L1 = ttnn.L1_MEMORY_CONFIG
# Toggle accumulation precision for both TTL and TTNN kernels.
# True = f32 dest acc (HiFi4) -- higher accuracy, slower
# False = bf16 dest acc (HiFi2) -- lower accuracy, faster
FP32_ACC = True
# Shared compute config so the ttnn baselines run at the same fidelity as TTL.
TTNN_COMPUTE_CONFIG = ttnn.WormholeComputeKernelConfig(
    math_fidelity=ttnn.MathFidelity.HiFi4 if FP32_ACC else ttnn.MathFidelity.HiFi2,
    fp32_dest_acc_en=FP32_ACC,
    packer_l1_acc=True,
)
MAX_GRID_N = 13 # Wormhole max grid columns
MAX_GRID_M = 10 # Wormhole max grid rows
def _even_split(n_blocks, max_grid):
"""Pick blocks_per_node that divides n_blocks evenly.
Avoids pipe deadlock when remainder cores skip iterations that src cores don't."""
bpn = -(-n_blocks // max_grid)
while n_blocks % bpn != 0:
bpn += 1
return bpn, n_blocks // bpn
# ---------------------------------------------------------------------------
# Kernel: mcast matmul (both A+B on dm_read)
# ---------------------------------------------------------------------------
def make_mcast_kernel(M_DIM, K_DIM, N_DIM):
    """Build a TTL matmul op that multicasts both A and B on dm_read.

    Grid rows broadcast A blocks (leader at column 0), grid columns broadcast
    B blocks (leader at row 0); each core accumulates its own output block
    over K_BLOCKS partial products. Assumes all dims are multiples of
    BLOCK_SIZE — TODO confirm with callers.
    """
    M_BLOCKS = M_DIM // BLOCK_SIZE
    N_BLOCKS = N_DIM // BLOCK_SIZE
    K_BLOCKS = K_DIM // BLOCK_SIZE
    # Exact splits so every core runs identical loop counts (pipe-deadlock
    # avoidance — see _even_split).
    M_BPN, NUM_ROWS = _even_split(M_BLOCKS, MAX_GRID_M)
    N_BPN, NUM_COLS = _even_split(N_BLOCKS, MAX_GRID_N)

    @ttl.operation(grid=(NUM_COLS, NUM_ROWS), fp32_dest_acc_en=FP32_ACC)
    def mcast_matmul(a, w, out):
        m_blocks_per_node = M_BPN
        n_blocks_per_node = N_BPN
        # One pipe per grid row: core (0, row) sources A for the whole row.
        a_pipes = [
            ttl.Pipe(src=(0, row), dst=(slice(0, NUM_COLS), row))
            for row in range(NUM_ROWS)
        ]
        mcast_a_net = ttl.PipeNet(a_pipes)
        # One pipe per grid column: core (col, 0) sources B for the column.
        b_pipes = [
            ttl.Pipe(src=(col, 0), dst=(col, slice(0, NUM_ROWS)))
            for col in range(NUM_COLS)
        ]
        mcast_b_net = ttl.PipeNet(b_pipes)
        # block_count=2: double-buffered so data movement overlaps compute.
        a_cb = ttl.make_dataflow_buffer_like(a, shape=(BLOCK_M, BLOCK_K), block_count=2)
        b_cb = ttl.make_dataflow_buffer_like(w, shape=(BLOCK_K, BLOCK_N), block_count=2)
        out_cb = ttl.make_dataflow_buffer_like(out, shape=(BLOCK_M, BLOCK_N), block_count=2)

        @ttl.compute()
        def compute():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    out_blk = out_cb.reserve()
                    # Full K reduction accumulated into one output block.
                    for _ in range(K_BLOCKS):
                        a_blk = a_cb.wait()
                        b_blk = b_cb.wait()
                        out_blk += a_blk @ b_blk
                        a_blk.pop()
                        b_blk.pop()
                    out_blk.push()

        @ttl.datamovement()
        def dm_read():
            # Feeds BOTH operands; loop structure must mirror compute exactly.
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M  # row offset (presumably tile units — confirm)
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    nc = nb * BLOCK_N
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        # A: row leader reads DRAM then multicasts; others receive.
                        with a_cb.reserve() as a_blk:
                            mcast_a_net.if_src(lambda pipe: (
                                ttl.copy(a[mr:mr + BLOCK_M, kc:kc + BLOCK_K], a_blk).wait(),
                                ttl.copy(a_blk, pipe).wait(),
                            ))
                            mcast_a_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, a_blk).wait(),
                            ))
                        # B: column leader reads then multicasts; others receive.
                        with b_cb.reserve() as b_blk:
                            mcast_b_net.if_src(lambda pipe: (
                                ttl.copy(w[kc:kc + BLOCK_K, nc:nc + BLOCK_N], b_blk).wait(),
                                ttl.copy(b_blk, pipe).wait(),
                            ))
                            mcast_b_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, b_blk).wait(),
                            ))

        @ttl.datamovement()
        def dm_write():
            # Drains one finished output block per (mb, nb) pair.
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    nc = nb * BLOCK_N
                    with out_cb.wait() as out_blk:
                        ttl.copy(out_blk, out[mr:mr + BLOCK_M, nc:nc + BLOCK_N]).wait()

    return mcast_matmul
# ---------------------------------------------------------------------------
# Kernel: balanced matmul (A on dm_read, B on dm_write, grid="auto")
# ---------------------------------------------------------------------------
def make_balanced_kernel(M_DIM, K_DIM, N_DIM):
    """Build the "balanced" TTL matmul: A traffic on dm_read, B traffic plus
    output write-back on dm_write, splitting the load across the two
    data-movement processors (presumably one NoC each — confirm).
    """
    M_BLOCKS = M_DIM // BLOCK_SIZE
    N_BLOCKS = N_DIM // BLOCK_SIZE
    K_BLOCKS = K_DIM // BLOCK_SIZE
    # Exact splits keep per-core loop counts identical (see _even_split).
    M_BPN, NUM_ROWS = _even_split(M_BLOCKS, MAX_GRID_M)
    N_BPN, NUM_COLS = _even_split(N_BLOCKS, MAX_GRID_N)

    @ttl.operation(grid=(NUM_COLS, NUM_ROWS), fp32_dest_acc_en=FP32_ACC)
    def balanced_matmul(a, w, out):
        m_blocks_per_node = M_BPN
        n_blocks_per_node = N_BPN
        # Row-broadcast pipes for A (leader at column 0 of each row).
        a_pipes = [
            ttl.Pipe(src=(0, row), dst=(slice(0, NUM_COLS), row))
            for row in range(NUM_ROWS)
        ]
        mcast_a_net = ttl.PipeNet(a_pipes)
        # Column-broadcast pipes for B (leader at row 0 of each column).
        b_pipes = [
            ttl.Pipe(src=(col, 0), dst=(col, slice(0, NUM_ROWS)))
            for col in range(NUM_COLS)
        ]
        mcast_b_net = ttl.PipeNet(b_pipes)
        # Double-buffered input/output circular buffers.
        a_cb = ttl.make_dataflow_buffer_like(a, shape=(BLOCK_M, BLOCK_K), block_count=2)
        b_cb = ttl.make_dataflow_buffer_like(w, shape=(BLOCK_K, BLOCK_N), block_count=2)
        out_cb = ttl.make_dataflow_buffer_like(out, shape=(BLOCK_M, BLOCK_N), block_count=2)

        @ttl.compute()
        def compute():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    out_blk = out_cb.reserve()
                    # Accumulate the full K reduction, then publish the block.
                    for _ in range(K_BLOCKS):
                        a_blk = a_cb.wait()
                        b_blk = b_cb.wait()
                        out_blk += a_blk @ b_blk
                        a_blk.pop()
                        b_blk.pop()
                    out_blk.push()

        @ttl.datamovement()
        def dm_read():
            # Only the A operand moves on this processor.
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    # nb itself is unused here; the loop count must still
                    # match compute's consumption of a_cb.
                    nb = node_n * n_blocks_per_node + local_nb
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        with a_cb.reserve() as a_blk:
                            mcast_a_net.if_src(lambda pipe: (
                                ttl.copy(a[mr:mr + BLOCK_M, kc:kc + BLOCK_K], a_blk).wait(),
                                ttl.copy(a_blk, pipe).wait(),
                            ))
                            mcast_a_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, a_blk).wait(),
                            ))

        @ttl.datamovement()
        def dm_write():
            # B multicast plus the final output write-back share this processor.
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    nc = nb * BLOCK_N
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        with b_cb.reserve() as b_blk:
                            mcast_b_net.if_src(lambda pipe: (
                                ttl.copy(w[kc:kc + BLOCK_K, nc:nc + BLOCK_N], b_blk).wait(),
                                ttl.copy(b_blk, pipe).wait(),
                            ))
                            mcast_b_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, b_blk).wait(),
                            ))
                    with out_cb.wait() as out_blk:
                        ttl.copy(out_blk, out[mr:mr + BLOCK_M, nc:nc + BLOCK_N]).wait()

    return balanced_matmul
# ---------------------------------------------------------------------------
# Kernel: balanced matmul + relu (fused activation on last K iteration)
# ---------------------------------------------------------------------------
def make_balanced_relu_kernel(M_DIM, K_DIM, N_DIM):
    """Build the balanced matmul with relu fused into the last K iteration.

    Same data movement as make_balanced_kernel; compute differs by keeping a
    ping-ponged accumulator buffer and applying relu on the final partial
    product before copying into the output buffer.
    """
    M_BLOCKS = M_DIM // BLOCK_SIZE
    N_BLOCKS = N_DIM // BLOCK_SIZE
    K_BLOCKS = K_DIM // BLOCK_SIZE
    # Exact splits keep per-core loop counts identical (see _even_split).
    M_BPN, NUM_ROWS = _even_split(M_BLOCKS, MAX_GRID_M)
    N_BPN, NUM_COLS = _even_split(N_BLOCKS, MAX_GRID_N)

    @ttl.operation(grid=(NUM_COLS, NUM_ROWS), fp32_dest_acc_en=FP32_ACC)
    def balanced_matmul_relu(a, w, out):
        m_blocks_per_node = M_BPN
        n_blocks_per_node = N_BPN
        # Row-broadcast pipes for A; column-broadcast pipes for B.
        a_pipes = [
            ttl.Pipe(src=(0, row), dst=(slice(0, NUM_COLS), row))
            for row in range(NUM_ROWS)
        ]
        mcast_a_net = ttl.PipeNet(a_pipes)
        b_pipes = [
            ttl.Pipe(src=(col, 0), dst=(col, slice(0, NUM_ROWS)))
            for col in range(NUM_COLS)
        ]
        mcast_b_net = ttl.PipeNet(b_pipes)
        a_cb = ttl.make_dataflow_buffer_like(a, shape=(BLOCK_M, BLOCK_K), block_count=2)
        b_cb = ttl.make_dataflow_buffer_like(w, shape=(BLOCK_K, BLOCK_N), block_count=2)
        # acc_cb is ping-ponged: every K step reads the previous partial sum
        # and writes the updated one, so block_count=2 holds both at once.
        acc_cb = ttl.make_dataflow_buffer_like(out, shape=(BLOCK_M, BLOCK_N), block_count=2)
        out_cb = ttl.make_dataflow_buffer_like(out, shape=(BLOCK_M, BLOCK_N), block_count=1)

        @ttl.compute()
        def compute():
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    # Seed the accumulator chain with a zero block.
                    with acc_cb.reserve() as init:
                        init.store(ttl.math.fill(init, 0))
                    for kb in range(K_BLOCKS):
                        # NOTE(review): `next` shadows the builtin; harmless
                        # here but worth renaming if this code is touched.
                        with (
                            a_cb.wait() as a_blk,
                            b_cb.wait() as b_blk,
                            acc_cb.wait() as last,
                            acc_cb.reserve() as next,
                        ):
                            if kb < K_BLOCKS - 1:
                                next.store(last + a_blk @ b_blk)
                            else:
                                # Fused activation on the final partial product.
                                next.store(ttl.math.relu(last + a_blk @ b_blk))
                    with acc_cb.wait() as result, out_cb.reserve() as o:
                        o.store(result)

        @ttl.datamovement()
        def dm_read():
            # A operand only; loop counts mirror compute's a_cb consumption.
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        with a_cb.reserve() as a_blk:
                            mcast_a_net.if_src(lambda pipe: (
                                ttl.copy(a[mr:mr + BLOCK_M, kc:kc + BLOCK_K], a_blk).wait(),
                                ttl.copy(a_blk, pipe).wait(),
                            ))
                            mcast_a_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, a_blk).wait(),
                            ))

        @ttl.datamovement()
        def dm_write():
            # B multicast plus output write-back.
            node_n, node_m = ttl.node(dims=2)
            for local_mb in range(m_blocks_per_node):
                mb = node_m * m_blocks_per_node + local_mb
                mr = mb * BLOCK_M
                for local_nb in range(n_blocks_per_node):
                    nb = node_n * n_blocks_per_node + local_nb
                    nc = nb * BLOCK_N
                    for kb in range(K_BLOCKS):
                        kc = kb * BLOCK_K
                        with b_cb.reserve() as b_blk:
                            mcast_b_net.if_src(lambda pipe: (
                                ttl.copy(w[kc:kc + BLOCK_K, nc:nc + BLOCK_N], b_blk).wait(),
                                ttl.copy(b_blk, pipe).wait(),
                            ))
                            mcast_b_net.if_dst(lambda pipe: (
                                ttl.copy(pipe, b_blk).wait(),
                            ))
                    with out_cb.wait() as out_blk:
                        ttl.copy(out_blk, out[mr:mr + BLOCK_M, nc:nc + BLOCK_N]).wait()

    return balanced_matmul_relu
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def bench(fn, *args, warmup=2, iters=5):
    """Time fn(*args) on device: run `warmup` untimed calls, then `iters`
    timed ones, synchronizing the device after each timed call.

    Returns (best_ms, mean_ms).
    """
    for _ in range(warmup):
        fn(*args)
    ttnn.synchronize_device(args[0].device())

    def _timed_call():
        begin = time.perf_counter()
        fn(*args)
        ttnn.synchronize_device(args[0].device())
        return (time.perf_counter() - begin) * 1000

    samples = [_timed_call() for _ in range(iters)]
    return min(samples), sum(samples) / len(samples)
def to_device(t, device, memory_config):
    """Upload a torch tensor to the device as tiled bfloat16."""
    return ttnn.from_torch(
        t.contiguous(),
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=memory_config,
    )
def make_tensors(M, K, N, device, memory_config):
    """Create device tensors for one benchmark run: random A (MxK) and
    W (KxN) scaled down to keep bf16 accumulation stable, plus a zeroed
    output (MxN)."""
    import torch
    operand_a = torch.randn(M, K, dtype=torch.bfloat16) * 0.02
    operand_w = torch.randn(K, N, dtype=torch.bfloat16) * 0.02
    result = torch.zeros(M, N, dtype=torch.bfloat16)
    return (
        to_device(operand_a, device, memory_config),
        to_device(operand_w, device, memory_config),
        to_device(result, device, memory_config),
    )
# ---------------------------------------------------------------------------
# Shapes
# ---------------------------------------------------------------------------
# Benchmark shapes as (M, K, N, label). Every dimension is a multiple of
# BLOCK_SIZE (256), so the kernels' block decomposition is exact.
# Small shapes are additionally run with L1-interleaved inputs (see __main__).
SHAPES_SMALL = [
    (1024, 1024, 1024, "1k^3"),
    (1024, 2048, 1024, "1k x 2k x 1k"),
    (2048, 2048, 2048, "2k^3"),
    (2048, 4096, 2048, "2k x 4k x 2k"),
    (2560, 2048, 3072, "2.5k x 2k x 3k"),
]
# Larger shapes are run with DRAM inputs only (see __main__).
SHAPES_LARGE = [
    (2048, 8192, 2048, "2k x 8k x 2k (long K)"),
    (2560, 4096, 3072, "2.5k x 4k x 3k"),
    (2560, 8192, 3072, "2.5k x 8k x 3k (120 cores)"),
    (2560, 8192, 3328, "2.5k x 8k x 3.3k (130 cores)"),
    (1024, 16384, 2560, "1k x 16k x 2.5k (tall K)"),
    (4096, 4096, 4096, "4k^3"),
    (4096, 8192, 4096, "4k x 8k x 4k"),
    (8192, 8192, 8192, "8k^3"),
    (10240, 8192, 13312, "10k x 8k x 13k (130 cores, 4x4)"),
]
def dealloc(*tensors):
    """Release every given device tensor."""
    for buf in tensors:
        ttnn.deallocate(buf)
def bench_and_cleanup(fn, *tensors, device):
    """Benchmark fn(*tensors), then free the tensors and let the device
    settle before the next run. Returns bench's (best_ms, mean_ms)."""
    timings = bench(fn, *tensors)
    dealloc(*tensors)
    ttnn.synchronize_device(device)
    time.sleep(0.01)  # brief pause between runs
    return timings
def run_matmul_shape(M, K, N, label, device, mem_config=DRAM):
    """Benchmark TTL mcast and balanced matmul kernels against ttnn.matmul
    for one shape, print a table row, and return the best times keyed for
    save_plot.
    """
    mc_label = "L1" if mem_config is not DRAM else ""
    mcast_fn = make_mcast_kernel(M, K, N)
    a, w, o = make_tensors(M, K, N, device, mem_config)
    mc_best, mc_avg = bench_and_cleanup(mcast_fn, a, w, o, device=device)
    bal_fn = make_balanced_kernel(M, K, N)
    a2, w2, o2 = make_tensors(M, K, N, device, mem_config)
    bal_best, bal_avg = bench_and_cleanup(bal_fn, a2, w2, o2, device=device)

    def ttnn_mm(a, w):
        return ttnn.matmul(a, w, compute_kernel_config=TTNN_COMPUTE_CONFIG)

    a3, w3, o3 = make_tensors(M, K, N, device, mem_config)
    # Fix: the preallocated output is unused by the ttnn baseline (matmul
    # allocates its own result) and was previously leaked on every shape.
    dealloc(o3)
    ttnn_best, ttnn_avg = bench_and_cleanup(ttnn_mm, a3, w3, device=device)
    mc_r = f"{mc_best / ttnn_best:.2f}x"
    bal_r = f"{bal_best / ttnn_best:.2f}x"
    lbl = f"{label} {mc_label}" if mc_label else label
    print(
        f"{lbl:<32} "
        f"{mc_best:>9.2f}ms {mc_avg:>8.2f}ms "
        f"{bal_best:>9.2f}ms {bal_avg:>8.2f}ms "
        f"{ttnn_best:>9.2f}ms {ttnn_avg:>8.2f}ms "
        f"{mc_r:>8} {bal_r:>9}"
    )
    return {"label": lbl, "flops": 2*M*K*N, "mcast": mc_best, "balanced": bal_best, "ttnn": ttnn_best}
def run_relu_shape(M, K, N, label, device, mem_config=DRAM):
    """Benchmark the fused TTL matmul+relu kernel against ttnn.linear (fused)
    and ttnn.matmul followed by ttnn.relu (separate) for one shape, print a
    table row, and return the best times keyed for save_plot.
    """
    mc_label = "L1" if mem_config is not DRAM else ""
    relu_fn = make_balanced_relu_kernel(M, K, N)
    a, w, o = make_tensors(M, K, N, device, mem_config)
    ttl_best, ttl_avg = bench_and_cleanup(relu_fn, a, w, o, device=device)

    def linear_relu(a, w):
        return ttnn.linear(a, w, activation="relu", compute_kernel_config=TTNN_COMPUTE_CONFIG)

    a2, w2, o2 = make_tensors(M, K, N, device, mem_config)
    # Fix: the preallocated output is unused by the ttnn baselines (they
    # allocate their own results) and was previously leaked on every shape.
    dealloc(o2)
    lin_best, lin_avg = bench_and_cleanup(linear_relu, a2, w2, device=device)

    def mm_relu(a, w):
        return ttnn.relu(ttnn.matmul(a, w, compute_kernel_config=TTNN_COMPUTE_CONFIG))

    a3, w3, o3 = make_tensors(M, K, N, device, mem_config)
    dealloc(o3)  # same leak fix as above
    sep_best, sep_avg = bench_and_cleanup(mm_relu, a3, w3, device=device)
    r_lin = f"{ttl_best / lin_best:.2f}x"
    r_sep = f"{ttl_best / sep_best:.2f}x"
    lbl = f"{label} {mc_label}" if mc_label else label
    print(
        f"{lbl:<32} "
        f"{ttl_best:>9.2f}ms {ttl_avg:>8.2f}ms "
        f"{lin_best:>10.2f}ms {lin_avg:>8.2f}ms "
        f"{sep_best:>9.2f}ms {sep_avg:>8.2f}ms "
        f"{r_lin:>8} {r_sep:>8}"
    )
    return {"label": lbl, "flops": 2*M*K*N, "ttl_relu": ttl_best, "ttnn_linear": lin_best, "mm_relu": sep_best}
def save_plot(matmul_data, relu_data, path="/tmp/bench_matmul_sweep.png"):
    """Generate performance plots with shape labels on x-axis.

    Three stacked panels: absolute matmul times (ax1), absolute fused-relu
    times (ax2), and TTL/TTNN ratios (ax3). `matmul_data` / `relu_data` are
    lists of the dicts returned by run_matmul_shape / run_relu_shape.
    Prints a message and returns early if matplotlib is unavailable.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend; no display needed
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError:
        print("matplotlib not available, skipping plot")
        return
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(16, 18))
    # --- Top: Matmul sweep ---
    # Group by shape label (without L1 suffix), show DRAM and L1 variants
    shape_order = []
    series = {
        "ttnn DRAM": {}, "ttnn L1": {},
        "ttl balanced DRAM": {}, "ttl balanced L1": {},
        "ttl mcast DRAM": {}, "ttl mcast L1": {},
    }
    for d in matmul_data:
        lbl = d["label"].strip()
        # run_matmul_shape appends " L1" for L1 runs; split that back out.
        is_l1 = lbl.endswith("L1")
        base = lbl.replace(" L1", "").strip() if is_l1 else lbl
        suffix = " L1" if is_l1 else " DRAM"
        if base not in shape_order:
            shape_order.append(base)
        if d.get("ttnn") is not None:
            series[f"ttnn{suffix}"][base] = d["ttnn"]
        if d.get("balanced") is not None:
            series[f"ttl balanced{suffix}"][base] = d["balanced"]
        if d.get("mcast") is not None:
            series[f"ttl mcast{suffix}"][base] = d["mcast"]
    x = np.arange(len(shape_order))
    # Paired palette: dark shade = DRAM, light shade = matching L1 variant.
    colors = {
        "ttnn DRAM": "#1f77b4", "ttnn L1": "#aec7e8",
        "ttl balanced DRAM": "#2ca02c", "ttl balanced L1": "#98df8a",
        "ttl mcast DRAM": "#ff7f0e", "ttl mcast L1": "#ffbb78",
    }
    markers = {
        "ttnn DRAM": "s", "ttnn L1": "s",
        "ttl balanced DRAM": "o", "ttl balanced L1": "o",
        "ttl mcast DRAM": "^", "ttl mcast L1": "^",
    }
    linestyles = {k: "-" if "DRAM" in k else "--" for k in series}
    for name, data in series.items():
        if not data:
            continue
        # Keep x positions aligned with shape_order; skip missing points.
        vals = [data.get(s) for s in shape_order]
        xs = [i for i, v in enumerate(vals) if v is not None]
        ys = [v for v in vals if v is not None]
        if ys:
            ax1.plot(xs, ys, marker=markers[name], color=colors[name],
                     linestyle=linestyles[name], label=name, linewidth=1.5,
                     markersize=5, alpha=0.85)
    ax1.set_xticks(x)
    ax1.set_xticklabels(shape_order, rotation=45, ha="right", fontsize=8)
    ax1.set_ylabel("Time (ms)")
    ax1.set_title("Matmul: TTL vs TTNN (lower is better)")
    ax1.legend(fontsize=7, ncol=3, loc="upper left")
    ax1.set_yscale("log")
    ax1.grid(True, alpha=0.3, axis="y")
    # --- Bottom: Fused relu ---
    relu_shapes = []
    relu_series = {
        "ttl relu DRAM": {}, "ttl relu L1": {},
        "ttnn.linear DRAM": {}, "ttnn.linear L1": {},
        "ttnn.matmul+ttnn.relu DRAM": {}, "ttnn.matmul+ttnn.relu L1": {},
    }
    for d in relu_data:
        lbl = d["label"].strip()
        is_l1 = lbl.endswith("L1")
        base = lbl.replace(" L1", "").strip() if is_l1 else lbl
        suffix = " L1" if is_l1 else " DRAM"
        if base not in relu_shapes:
            relu_shapes.append(base)
        relu_series[f"ttl relu{suffix}"][base] = d["ttl_relu"]
        relu_series[f"ttnn.linear{suffix}"][base] = d["ttnn_linear"]
        relu_series[f"ttnn.matmul+ttnn.relu{suffix}"][base] = d["mm_relu"]
    if relu_shapes:
        x2 = np.arange(len(relu_shapes))
        relu_colors = {
            "ttl relu DRAM": "#2ca02c", "ttl relu L1": "#98df8a",
            "ttnn.linear DRAM": "#1f77b4", "ttnn.linear L1": "#aec7e8",
            "ttnn.matmul+ttnn.relu DRAM": "#d62728", "ttnn.matmul+ttnn.relu L1": "#ff9896",
        }
        relu_markers = {
            "ttl relu DRAM": "o", "ttl relu L1": "o",
            "ttnn.linear DRAM": "s", "ttnn.linear L1": "s",
            "ttnn.matmul+ttnn.relu DRAM": "^", "ttnn.matmul+ttnn.relu L1": "^",
        }
        relu_ls = {k: "-" if "DRAM" in k else "--" for k in relu_series}
        for name, data in relu_series.items():
            if not data:
                continue
            vals = [data.get(s) for s in relu_shapes]
            xs = [i for i, v in enumerate(vals) if v is not None]
            ys = [v for v in vals if v is not None]
            if ys:
                ax2.plot(xs, ys, marker=relu_markers[name], color=relu_colors[name],
                         linestyle=relu_ls[name], label=name, linewidth=1.5,
                         markersize=5, alpha=0.85)
        ax2.set_xticks(x2)
        ax2.set_xticklabels(relu_shapes, rotation=45, ha="right", fontsize=8)
        ax2.set_ylabel("Time (ms)")
        ax2.set_title("Fused Matmul+ReLU: TTL vs TTNN (lower is better)")
        ax2.legend(fontsize=7, ncol=3, loc="upper left")
        ax2.set_yscale("log")
        ax2.grid(True, alpha=0.3, axis="y")
    # --- Bottom: Ratio plot (TTL / TTNN baseline), all shapes merged ---
    # Matmul lines: ratio vs ttnn DRAM
    # Relu lines: ratio vs ttnn.matmul+ttnn.relu DRAM
    all_shapes = []
    ratio_series = {}
    # Matmul ratios: all TTL variants / ttnn DRAM
    ttnn_dram = series.get("ttnn DRAM", {})
    for ttl_key in ("ttl mcast DRAM", "ttl mcast L1", "ttl balanced DRAM", "ttl balanced L1"):
        ttl_data = series.get(ttl_key, {})
        if not ttl_data:
            continue
        ratios = {}
        for shape in shape_order:
            ttl_val = ttl_data.get(shape)
            base_val = ttnn_dram.get(shape)
            if ttl_val is not None and base_val is not None:
                ratios[shape] = ttl_val / base_val
                if shape not in all_shapes:
                    all_shapes.append(shape)
        if ratios:
            ratio_series[ttl_key] = ratios
    # Relu ratios: all TTL relu variants / ttnn.matmul+ttnn.relu DRAM
    mm_relu_dram = relu_series.get("ttnn.matmul+ttnn.relu DRAM", {})
    for ttl_key in ("ttl relu DRAM", "ttl relu L1"):
        ttl_data = relu_series.get(ttl_key, {})
        if not ttl_data:
            continue
        ratios = {}
        for shape in relu_shapes:
            ttl_val = ttl_data.get(shape)
            base_val = mm_relu_dram.get(shape)
            if ttl_val is not None and base_val is not None:
                ratios[shape] = ttl_val / base_val
                if shape not in all_shapes:
                    all_shapes.append(shape)
        if ratios:
            ratio_series[ttl_key] = ratios
    if all_shapes:
        x3 = np.arange(len(all_shapes))
        ratio_colors = {
            "ttl mcast DRAM": "#ff7f0e", "ttl mcast L1": "#ffbb78",
            "ttl balanced DRAM": "#2ca02c", "ttl balanced L1": "#98df8a",
            "ttl relu DRAM": "#9467bd", "ttl relu L1": "#c5b0d5",
        }
        ratio_markers = {
            "ttl mcast DRAM": "^", "ttl mcast L1": "^",
            "ttl balanced DRAM": "o", "ttl balanced L1": "o",
            "ttl relu DRAM": "s", "ttl relu L1": "s",
        }
        ratio_ls = {k: "-" if "DRAM" in k else "--" for k in ratio_colors}
        for name, data in ratio_series.items():
            if not data:
                continue
            vals = [data.get(s) for s in all_shapes]
            xs = [i for i, v in enumerate(vals) if v is not None]
            ys = [v for v in vals if v is not None]
            if ys:
                ax3.plot(xs, ys, marker=ratio_markers.get(name, "o"),
                         color=ratio_colors.get(name, "#333333"),
                         linestyle=ratio_ls.get(name, "-"), label=name,
                         linewidth=1.5, markersize=5, alpha=0.85)
        # Break-even line: below 1.0 means TTL beats the TTNN baseline.
        ax3.axhline(y=1.0, color="black", linestyle=":", linewidth=1, alpha=0.7)
        ax3.set_xticks(x3)
        ax3.set_xticklabels(all_shapes, rotation=45, ha="right", fontsize=8)
        ax3.set_ylabel("Ratio (TTL / TTNN)")
        ax3.set_title("TTL / TTNN Ratio (matmul vs ttnn.matmul DRAM, relu vs ttnn.matmul+ttnn.relu DRAM) -- below 1.0 = TTL wins")
        ax3.legend(fontsize=7, ncol=4, loc="upper right")
        ax3.grid(True, alpha=0.3, axis="y")
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    print(f"Plot saved to {path}")
if __name__ == "__main__":
    import torch
    torch.manual_seed(42)
    device = ttnn.open_device(device_id=0)
    matmul_data = []
    relu_data = []
    try:
        header = (
            f"{'Shape':<32} "
            f"{'mcast best':>11} {'mcast avg':>10} "
            f"{'bal best':>11} {'bal avg':>10} "
            f"{'ttnn best':>11} {'ttnn avg':>10} "
            f"{'mc/ttnn':>8} {'bal/ttnn':>9}"
        )
        sep = "-" * len(header)
        # --- Part 1: Matmul sweep ---
        acc_label = "f32" if FP32_ACC else "bf16"
        print(f"=== Mcast Matmul Sweep Benchmark (acc={acc_label}) ===")
        print(f"Block: {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, tile: {TILE}")
        print()
        print(header)
        print(sep)
        # Small shapes run twice: DRAM inputs, then L1-interleaved inputs.
        for M, K, N, label in SHAPES_SMALL:
            matmul_data.append(run_matmul_shape(M, K, N, label, device))
            matmul_data.append(run_matmul_shape(M, K, N, label, device, mem_config=L1))
        # Large shapes run with DRAM inputs only.
        for M, K, N, label in SHAPES_LARGE:
            matmul_data.append(run_matmul_shape(M, K, N, label, device))
        print(sep)
        # --- Part 2: Fused matmul+relu (balanced only) ---
        print()
        print("=== Fused Matmul+ReLU Benchmark (balanced, 1 block/core) ===")
        relu_header = (
            f"{'Shape':<32} "
            f"{'ttl relu':>11} {'ttl avg':>10} "
            f"{'ttnn.linear':>12} {'lin avg':>10} "
            f"{'ttnn.matmul+ttnn.relu':>22} {'sep avg':>10} "
            f"{'ttl/lin':>8} {'ttl/sep':>8}"
        )
        relu_sep = "-" * len(relu_header)
        print(relu_header)
        print(relu_sep)
        for M, K, N, label in SHAPES_SMALL:
            relu_data.append(run_relu_shape(M, K, N, label, device))
            relu_data.append(run_relu_shape(M, K, N, label, device, mem_config=L1))
        for M, K, N, label in SHAPES_LARGE:
            relu_data.append(run_relu_shape(M, K, N, label, device))
        print(relu_sep)
        # --- Plot ---
        save_plot(matmul_data, relu_data)
        # --- PCC smoke test ---
        import torch  # redundant (imported above); kept as-is
        M_t, K_t, N_t = 2560, 8192, 3328
        torch.manual_seed(42)
        a_torch = torch.randn(M_t, K_t, dtype=torch.bfloat16) * 0.02
        w_torch = torch.randn(K_t, N_t, dtype=torch.bfloat16) * 0.02
        # Reference computed in f32 on the host.
        ref = a_torch.float() @ w_torch.float()
        def to_dev(t):
            return ttnn.from_torch(t.contiguous(), dtype=ttnn.bfloat16,
                                   layout=ttnn.TILE_LAYOUT, device=device,
                                   memory_config=DRAM)
        def calc_pcc(result):
            # Pearson correlation between flattened result and reference.
            return torch.corrcoef(torch.stack([result.flatten(), ref.flatten()]))[0, 1].item()
        # TTL mcast
        out_dev = to_dev(torch.zeros(M_t, N_t, dtype=torch.bfloat16))
        mcast_fn = make_mcast_kernel(M_t, K_t, N_t)
        mcast_fn(to_dev(a_torch), to_dev(w_torch), out_dev)
        ttl_pcc = calc_pcc(ttnn.to_torch(out_dev).float())
        # TTNN
        ttnn_out = ttnn.matmul(to_dev(a_torch), to_dev(w_torch),
                               compute_kernel_config=TTNN_COMPUTE_CONFIG)
        ttnn_pcc = calc_pcc(ttnn.to_torch(ttnn_out).float())
        print()
        print(f"=== PCC Smoke Test ({M_t}x{K_t}x{N_t}, acc={acc_label}) ===")
        print(f" TTL mcast: {ttl_pcc:.6f}")
        print(f" TTNN: {ttnn_pcc:.6f}")
        assert ttl_pcc > 0.99, f"TTL PCC too low: {ttl_pcc}"
        assert ttnn_pcc > 0.99, f"TTNN PCC too low: {ttnn_pcc}"
        print(" PASS")
        print("=== Done ===")
    finally:
        # Always release the device, even if a benchmark or assertion fails.
        ttnn.close_device(device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment