Last active
August 13, 2025 03:56
-
-
Save Chillee/e2b07157caeade8c6b0bdf463d10f833 to your computer and use it in GitHub Desktop.
Cutlass Thread-Value Layout Visualizer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import cutlass.cute as cute | |
import cutlass | |
def _mode_size(mode) -> int:
    """Return the element count of a CuTe shape mode (an int or nested tuple of ints)."""
    if isinstance(mode, int):
        return mode
    return math.prod(_mode_size(sub) for sub in mode)


def visualize_tv_layout(
    tiler_mn: tuple[int, int],
    tv_layout,  # ((thr_shape, val_shape), (thr_stride, val_stride))
    *,
    font_size: int = 10,
    cell_px: int = 70,
    grid_lw: float = 1.5,
    color_fn=None,  # optional (tid, vid) -> colour
    filename: str = "tv_layout.svg",
):
    """Draw a T/V checkerboard for an arbitrary TV layout and save it to disk.

    Every (thread id, value id) pair is mapped through the CuTe layout built
    from ``tv_layout`` to a linear offset. That offset is interpreted
    column-major in an ``M x N`` tile (``m = offset % M``, ``n = offset // M``).
    Each cell is labelled ``T<tid>`` / ``V<vid>`` with the first pair that
    landed on it. Cells that no pair reaches are left black and unlabelled.

    Args:
        tiler_mn: ``(M, N)`` extent of the tile being visualised.
        tv_layout: ``(shape, stride)`` where ``shape = (thr_shape, val_shape)``
            and ``stride = (thr_stride, val_stride)``. Each mode may be an int
            or a (nested) tuple of ints.
        font_size: Font size of the cell labels; the title uses ``font_size + 2``.
        cell_px: Size of one cell in pixels at the figure's 100 dpi.
        grid_lw: Line width of the cell grid.
        color_fn: Optional ``(tid, vid) -> colour`` callable returning anything
            ``matplotlib.colors.to_rgb`` accepts. Defaults to one Set3 pastel
            colour per thread, cycling every 12 threads.
        filename: Path the figure is written to. Defaults to ``"tv_layout.svg"``.

    Raises:
        ValueError: If the layout maps some (tid, vid) outside the ``M x N`` tile.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors

    # -----------------------------------------------------------------
    # 1) Sizes of the thread and value modes (modes may be nested tuples)
    # -----------------------------------------------------------------
    shape, stride = tv_layout
    n_thr = _mode_size(shape[0])
    n_val = _mode_size(shape[1])

    M, N = tiler_mn
    # -1 marks a cell that no (tid, vid) pair has reached yet.
    thr_ids = np.full((M, N), -1, dtype=int)
    val_ids = np.full((M, N), -1, dtype=int)

    # -----------------------------------------------------------------
    # 2) Query CuTe for every (tid, vid) -> linear offset
    # -----------------------------------------------------------------
    @cute.jit
    def g():
        layout = cute.make_layout(shape, stride=stride)
        tid_vals = []
        for tid in cutlass.range_constexpr(n_thr):
            vid_vals = []
            for vid in cutlass.range_constexpr(n_val):
                vid_vals.append(layout((tid, vid)))
            tid_vals.append(vid_vals)
        return tid_vals

    vals = g()  # vals[tid][vid] is the linear offset of that pair

    for tid in range(n_thr):
        for vid in range(n_val):
            pos = vals[tid][vid]
            # Column-major decomposition of the offset into (m, n).
            n = pos // M
            m = pos % M
            # Guard before indexing: a negative n would otherwise wrap
            # silently via numpy negative indexing.
            if not (0 <= n < N):
                raise ValueError(
                    f"tv_layout maps (T{tid}, V{vid}) to offset {pos}, "
                    f"outside the {M}x{N} tile"
                )
            if thr_ids[m, n] >= 0:
                continue  # first (tid, vid) to reach a cell wins
            thr_ids[m, n] = tid
            val_ids[m, n] = vid

    # -----------------------------------------------------------------
    # 3) Colours (default: pastel per-thread)
    # -----------------------------------------------------------------
    if color_fn is None:
        pastel = plt.cm.Set3.colors
        color_fn = lambda t, v: pastel[t % len(pastel)]

    bg_rgb = np.zeros((M, N, 3))  # unreached cells stay black
    for m in range(M):
        for n in range(N):
            tid = thr_ids[m, n]
            if tid >= 0:
                bg_rgb[m, n] = mcolors.to_rgb(color_fn(tid, val_ids[m, n]))

    # -----------------------------------------------------------------
    # 4) Draw
    # -----------------------------------------------------------------
    fig_w, fig_h = N * cell_px / 100, M * cell_px / 100
    fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=100)
    ax.imshow(bg_rgb, interpolation="none")

    for m in range(M):
        for n in range(N):
            if thr_ids[m, n] >= 0:
                ax.text(
                    n, m, f"T{thr_ids[m,n]}\nV{val_ids[m,n]}",
                    ha="center", va="center",
                    fontsize=font_size, weight="bold",
                )

    # Ticks sit on cell boundaries (-0.5 offsets) so the grid outlines cells.
    ax.set_xticks(np.arange(N + 1) - 0.5)
    ax.set_yticks(np.arange(M + 1) - 0.5)
    ax.set_xticklabels([str(i) for i in range(N + 1)])
    ax.set_yticklabels([str(i) for i in range(M + 1)])
    ax.tick_params(axis='both', which='both', length=6, width=1)
    ax.tick_params(axis='x', which='both', top=True, bottom=False, labeltop=True, labelbottom=False)
    ax.tick_params(axis='y', which='both', left=True, right=False)
    ax.grid(which="major", color="black", linewidth=grid_lw)
    ax.set_xlim(-.5, N - .5)
    ax.set_ylim(M - .5, -.5)  # inverted so row 0 is at the top
    ax.set_title(f"tv_layout = {tv_layout}", fontsize=font_size + 2, pad=12)

    fig.tight_layout()
    fig.savefig(filename)
    # Release the figure; pyplot otherwise keeps every figure alive across calls.
    plt.close(fig)
# Demo: an 8 x 8 tile covered by 8 threads (2*2*2) x 8 values (2*2*2).
tiler_mn = (8, 8)
tv = (
    # shape:  (thread mode, value mode)
    ((2, 2, 2), (2, 2, 2)),
    # stride: (thread mode, value mode)
    ((1, 16, 4), (8, 2, 32)),
)
visualize_tv_layout(tiler_mn=tiler_mn, tv_layout=tv)
Author
Chillee
commented
Jul 29, 2025

tv = (
((2, 2, 2), (2, 2, 2)), # thr_shape / thr_stride
((1, 16, 4), (8, 2, 32)), # val_shape / val_stride
)
Could you update the comments to read `# thr_shape / val_shape` and `# thr_stride / val_stride`, respectively?
I got tripped up by this at first.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment