Skip to content

Instantly share code, notes, and snippets.

@Chillee
Last active August 13, 2025 03:56
Show Gist options
  • Save Chillee/e2b07157caeade8c6b0bdf463d10f833 to your computer and use it in GitHub Desktop.
Save Chillee/e2b07157caeade8c6b0bdf463d10f833 to your computer and use it in GitHub Desktop.
Cutlass Thread-Value Layout Visualizer
import math
import cutlass.cute as cute
import cutlass
def _mode_size(mode) -> int:
    """Return the element count of a CuTe shape mode.

    ``mode`` is either an ``int`` or a (possibly nested) tuple of modes; a
    tuple's size is the product of its sub-modes' sizes. Recursion matters
    for hierarchical shapes such as ``((2, 2), 2)``: a flat ``math.prod``
    would multiply the inner tuple as a sequence (repetition) instead of
    as a number.
    """
    if isinstance(mode, int):
        return mode
    return math.prod(_mode_size(sub) for sub in mode)


def visualize_tv_layout(
    tiler_mn: tuple[int, int],
    tv_layout,  # ((thr_shape, val_shape), (thr_stride, val_stride))
    *,
    font_size: int = 10,
    cell_px: int = 70,
    grid_lw: float = 1.5,
    color_fn=None,  # optional (tid, vid) -> colour
    out_path: str = "tv_layout.svg",
):
    """Draw a T/V checkerboard for an arbitrary TV layout.

    Each cell of the ``M x N`` tile is labelled with the thread id (``T``)
    and value id (``V``) that the layout maps onto it, and shaded using
    ``color_fn``. The figure is written to ``out_path``.

    Args:
        tiler_mn: ``(M, N)`` extents of the tile being drawn.
        tv_layout: ``(shape, stride)`` pair where ``shape`` is
            ``(thr_shape, val_shape)`` and ``stride`` is
            ``(thr_stride, val_stride)``. Each mode may be an ``int`` or a
            (nested) tuple. The layout maps ``(tid, vid)`` to a linear
            offset into the tile, read column-major
            (``m = offset % M``, ``n = offset // M``).
        font_size: Font size of the cell labels (the title uses
            ``font_size + 2``).
        cell_px: Size of one cell in pixels at the figure's 100 dpi.
        grid_lw: Line width of the cell grid.
        color_fn: Optional ``(tid, vid) -> colour`` callable returning any
            colour spec accepted by ``matplotlib.colors.to_rgb``. Defaults
            to one pastel colour per thread.
        out_path: File the figure is saved to. Defaults to
            ``"tv_layout.svg"``, the previously hard-coded name.

    Returns:
        The matplotlib ``Figure`` that was drawn and saved.

    Raises:
        ValueError: If the layout maps some ``(tid, vid)`` to an offset
            outside ``[0, M * N)``.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors

    # -----------------------------------------------------------------
    # 1) Thread / value counts from the (possibly nested) shape modes
    # -----------------------------------------------------------------
    shape, stride = tv_layout
    n_thr = _mode_size(shape[0])
    n_val = _mode_size(shape[1])

    M, N = tiler_mn
    num_cells = M * N
    thr_ids = np.full((M, N), -1, dtype=int)
    val_ids = np.full((M, N), -1, dtype=int)
    filled = np.zeros((M, N), dtype=bool)

    # -----------------------------------------------------------------
    # 2) Query CuTe for every (tid, vid) -> linear offset; the jitted
    #    function returns a nested list indexed [tid][vid].
    # -----------------------------------------------------------------
    @cute.jit
    def g():
        layout = cute.make_layout(shape, stride=stride)
        tid_vals = []
        for tid in cutlass.range_constexpr(n_thr):
            vid_vals = []
            for vid in cutlass.range_constexpr(n_val):
                vid_vals.append(layout((tid, vid)))
            tid_vals.append(vid_vals)
        return tid_vals

    vals = g()

    # -----------------------------------------------------------------
    # 3) Place each (tid, vid) in its (m, n) cell
    # -----------------------------------------------------------------
    for tid in range(n_thr):
        for vid in range(n_val):
            pos = int(vals[tid][vid])
            if not 0 <= pos < num_cells:
                # Without this check a negative offset would silently wrap
                # to the far edge of the grid (numpy negative indexing) and
                # an offset past the tile would raise a bare IndexError.
                raise ValueError(
                    f"tv_layout maps (tid={tid}, vid={vid}) to offset {pos}, "
                    f"outside the {M}x{N} tile (valid offsets 0..{num_cells - 1})"
                )
            # Column-major: offset = m + n * M.
            n, m = divmod(pos, M)
            if filled[m, n]:
                # Several (tid, vid) pairs can land on the same cell; the
                # first one encountered is the one drawn.
                continue
            thr_ids[m, n] = tid
            val_ids[m, n] = vid
            filled[m, n] = True

    # -----------------------------------------------------------------
    # 4) Colours (default: one pastel colour per thread)
    # -----------------------------------------------------------------
    if color_fn is None:
        pastel = plt.cm.Set3.colors

        def color_fn(t, v):
            return pastel[t % len(pastel)]

    owned = np.argwhere(filled)  # (m, n) of every assigned cell, row-major
    bg_rgb = np.zeros((M, N, 3))  # unassigned cells stay black
    for m, n in owned:
        bg_rgb[m, n] = mcolors.to_rgb(color_fn(thr_ids[m, n], val_ids[m, n]))

    # -----------------------------------------------------------------
    # 5) Draw
    # -----------------------------------------------------------------
    fig, ax = plt.subplots(
        figsize=(N * cell_px / 100, M * cell_px / 100), dpi=100
    )
    ax.imshow(bg_rgb, interpolation="none")
    for m, n in owned:
        ax.text(
            n, m, f"T{thr_ids[m, n]}\nV{val_ids[m, n]}",
            ha="center", va="center",
            fontsize=font_size, weight="bold",
        )

    # Ticks sit on cell boundaries (x.5) so the major grid outlines cells.
    ax.set_xticks(np.arange(N + 1) - 0.5)
    ax.set_yticks(np.arange(M + 1) - 0.5)
    ax.set_xticklabels([str(i) for i in range(N + 1)])
    ax.set_yticklabels([str(i) for i in range(M + 1)])
    ax.tick_params(axis="both", which="both", length=6, width=1)
    # x ticks and labels on top, y ticks on the left.
    ax.tick_params(
        axis="x", which="both",
        top=True, bottom=False, labeltop=True, labelbottom=False,
    )
    ax.tick_params(axis="y", which="both", left=True, right=False)
    ax.grid(which="major", color="black", linewidth=grid_lw)
    ax.set_xlim(-0.5, N - 0.5)
    ax.set_ylim(M - 0.5, -0.5)  # inverted so row 0 is at the top
    ax.set_title(f"tv_layout = {tv_layout}", fontsize=font_size + 2, pad=12)
    fig.tight_layout()
    fig.savefig(out_path)
    return fig
# Example: 8 threads x 8 values on an 8x8 tile. The layout is assembled
# from explicitly named parts so shapes and strides cannot be mixed up.
tiler_mn = (8, 8)
thr_shape, val_shape = (2, 2, 2), (2, 2, 2)
thr_stride, val_stride = (1, 16, 4), (8, 2, 32)
tv = ((thr_shape, val_shape), (thr_stride, val_stride))
visualize_tv_layout(tiler_mn, tv)
@Chillee
Copy link
Author

Chillee commented Jul 29, 2025

image

@awgu
Copy link

awgu commented Jul 29, 2025

tv = (
    ((2, 2, 2), (2, 2, 2)),  # thr_shape / thr_stride
    ((1, 16, 4), (8, 2, 32)),  # val_shape / val_stride
)

Could you update the comments to read `# thr_shape / val_shape` and then `# thr_stride / val_stride`? I got tripped up by this at first.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment