Skip to content

Instantly share code, notes, and snippets.

@TomAugspurger
Created April 8, 2026 15:54
Show Gist options
  • Select an option

  • Save TomAugspurger/355766e1ca9e0a611d3aadeb556b18a3 to your computer and use it in GitHub Desktop.

Select an option

Save TomAugspurger/355766e1ca9e0a611d3aadeb556b18a3 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "polars>=1.0.0",
# "numpy>=1.24",
# "rich>=13",
# "pyarrow>=14",
# ]
# ///
#
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
"""
cudf-polars-report analyzer.
This tool analyzes Nsight Systems reports (.nsys-rep) of cudf-polars' benchmark runs.
uv run cpr.py summary report.nsys-rep
uv run cpr.py io report.nsys-rep --query 0 --iteration 1
Requires ``nsys`` on PATH to create the SQLite export unless one already exists.
See ``cpr.py summary --help`` for how Host, Wall, and Device times differ.
"""
from __future__ import annotations
import argparse
import json
import shutil
import numpy as np
import sqlite3
import struct
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import polars as pl
from rich.console import Console
from rich.table import Table
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Case-insensitive; full-line match for Query N - Iteration M (polars/rust regex)
QUERY_ITER_REGEX = r"(?i)^Query\s+(\d+)\s+-\s+Iteration\s+(\d+)\s*$"
# NVTX labels skipped by default in the operator summary (see --exclude / --all-types).
DEFAULT_EXCLUDE_LABELS: frozenset[str] = frozenset({"ConvertIR", "ExecuteIR"})
# nvtxPayloadType_t (nvToolsExt.h)
NVTX_PAYLOAD_UNSIGNED_INT64 = 1
NVTX_PAYLOAD_INT64 = 2
# Substring identifying kvikio remote-read NVTX ranges (byte counts in payload).
KVIKIO_READ_NAME_SUBSTR = "RemoteHandle::read"
# Logical table names used for the parquet extract cache files.
CACHE_TABLES = ("nvtx_events", "enum_nsys_event_type", "kernels_runtime")
# -----------------------------------------------------------------------------
# nsys export
# -----------------------------------------------------------------------------
def default_sqlite_path(nsys_rep: Path) -> Path:
    """Default SQLite export path: same directory and stem, ``.sqlite`` suffix."""
    return nsys_rep.parent / f"{nsys_rep.stem}.sqlite"
def run_nsys_export(nsys_rep: Path, sqlite_out: Path, *, force: bool) -> None:
    """Export *nsys_rep* to SQLite at *sqlite_out* with the ``nsys`` CLI.

    An existing export is reused unless *force* is set. Raises
    ``subprocess.CalledProcessError`` when the export fails.
    """
    if sqlite_out.exists() and not force:
        return
    sqlite_out.parent.mkdir(parents=True, exist_ok=True)
    command = [
        "nsys", "export",
        "--type", "sqlite",
        "--force-overwrite=true",
        "-q", "true",
        "-o", str(sqlite_out),
        str(nsys_rep),
    ]
    subprocess.run(command, check=True)
def ensure_nsys_available() -> None:
    """Exit with status 127 when the ``nsys`` binary is not on PATH."""
    if shutil.which("nsys") is not None:
        return
    print(
        "error: 'nsys' not found on PATH. Install Nsight Systems CLI.",
        file=sys.stderr,
    )
    raise SystemExit(127)
# -----------------------------------------------------------------------------
# SQLite helpers
# -----------------------------------------------------------------------------
def sqlite_table_columns(conn: sqlite3.Connection, table: str) -> frozenset[str]:
    """Return the column names of *table* (empty set if the table is absent).

    PRAGMA statements do not accept bound parameters, so the identifier is
    double-quoted (embedded quotes doubled) to keep unusual table names from
    breaking or injecting into the statement.
    """
    quoted = table.replace('"', '""')
    cur = conn.execute(f'PRAGMA table_info("{quoted}")')
    return frozenset(str(row[1]) for row in cur.fetchall())
def query_polars(conn: sqlite3.Connection, sql: str, params: tuple[Any, ...] = ()) -> pl.DataFrame:
    """Run *sql* with *params* and return the result set as a Polars DataFrame."""
    cur = conn.execute(sql, params)
    if cur.description is None:
        # Statement produced no result set (e.g. DDL).
        return pl.DataFrame()
    names = [desc[0] for desc in cur.description]
    fetched = cur.fetchall()
    if not fetched:
        return pl.DataFrame({name: [] for name in names})
    columns = {name: [record[idx] for record in fetched] for idx, name in enumerate(names)}
    return pl.DataFrame(columns)
# -----------------------------------------------------------------------------
# NVTX payload decoding (UINT64 byte counts from kvikio)
# -----------------------------------------------------------------------------
def decode_nvtx_uint64(
    raw: object,
    payload_type: int | None,
) -> int | None:
    """Decode an NVTX payload value to an unsigned 64-bit integer (byte count).

    *raw* may be an int, an integral float (SQLite can surface large UINT64
    values as REAL — see the matching note in ``add_payload_bytes_column``),
    or a little-endian blob (8-byte or 4-byte). Returns ``None`` when the
    value is absent, the payload type is unsupported, or the blob is too short.
    """
    if raw is None:
        return None
    # Accept untyped payloads (0/None) and the two 64-bit integer types only.
    if payload_type is not None and payload_type not in (
        0,
        NVTX_PAYLOAD_UNSIGNED_INT64,
        NVTX_PAYLOAD_INT64,
    ):
        return None
    # SQLite / drivers may deliver the value as a REAL; accept integral floats.
    if isinstance(raw, float):
        if not raw.is_integer():
            return None
        raw = int(raw)
    if isinstance(raw, int):
        # A negative signed-int64 payload is not a valid byte count.
        if raw < 0 and payload_type == NVTX_PAYLOAD_INT64:
            return None
        return int(raw) & 0xFFFFFFFFFFFFFFFF
    if isinstance(raw, (bytes, bytearray, memoryview)):
        buf = bytes(raw)
        if len(buf) >= 8:
            return struct.unpack_from("<Q", buf, 0)[0]
        if len(buf) >= 4:
            return int(struct.unpack_from("<I", buf, 0)[0])
    return None
# -----------------------------------------------------------------------------
# Interval helpers
# -----------------------------------------------------------------------------
def _interval_union_ns_sorted_arrays(s: np.ndarray, e: np.ndarray) -> int:
"""Union length of ``[s[i], e[i])`` with rows sorted by ``s`` ascending."""
if s.size == 0:
return 0
total = 0
cs, ce = int(s[0]), int(e[0])
for i in range(1, s.size):
si, ei = int(s[i]), int(e[i])
if si > ce:
total += ce - cs
cs, ce = si, ei
else:
ce = max(ce, ei)
total += ce - cs
return total
def merge_intervals_union_ns(iv: pl.DataFrame) -> int:
    """
    Total nanoseconds in the union of ``[start, end)`` intervals.

    Expects columns ``start`` and ``end`` (cast to ``Int64``). Degenerate rows
    (``end <= start``) are dropped; the rest are sorted and merged via the
    NumPy helper without materializing Python lists.
    """
    cleaned = (
        iv.select(pl.col("start").cast(pl.Int64), pl.col("end").cast(pl.Int64))
        .filter(pl.col("end") > pl.col("start"))
        .sort("start")
    )
    if cleaned.is_empty():
        return 0
    starts = cleaned.get_column("start").to_numpy()
    ends = cleaned.get_column("end").to_numpy()
    return _interval_union_ns_sorted_arrays(starts, ends)
def wall_union_ns_per_operator(df: pl.DataFrame) -> pl.DataFrame:
    """
    Per ``label_clean``, nanoseconds covered by the union of ``[start, end)``
    intervals on the global timeline (all instances of that operator merged).
    """
    if df.is_empty():
        return pl.DataFrame(schema={"operator": pl.Utf8, "wall_union_ns": pl.Int64})
    return (
        df.group_by("label_clean", maintain_order=False)
        # One single-row frame per group: the operator label plus the merged
        # length of all of its intervals.
        .map_groups(
            lambda g: pl.DataFrame(
                {
                    "operator": [str(g.get_column("label_clean")[0])],
                    "wall_union_ns": [merge_intervals_union_ns(g.select("start", "end"))],
                }
            )
        )
        .select(pl.col("operator"), pl.col("wall_union_ns"))
    )
# -----------------------------------------------------------------------------
# Extracted data loading (sqlite -> polars), optional parquet cache
# -----------------------------------------------------------------------------
@dataclass(frozen=True)
class NsysSchema:
    """Which NVTX payload columns this particular nsys SQLite export provides."""

    nvtx_uint64_col: str | None  # "uint64Value" or None
    nvtx_payload_col: str | None  # "payload"
    nvtx_payload_type_col: str | None  # "payloadType" column name, if present
def detect_nvtx_payload_schema(cols: frozenset[str]) -> NsysSchema:
    """Map the NVTX_EVENTS column set to the payload columns to read."""
    if "uint64Value" in cols:
        # Newer exports expose the decoded value directly.
        return NsysSchema(
            nvtx_uint64_col="uint64Value",
            nvtx_payload_col=None,
            nvtx_payload_type_col=None,
        )
    payload_col = "payload" if "payload" in cols else None
    type_col = "payloadType" if "payloadType" in cols else None
    return NsysSchema(None, payload_col, type_col)
def nvtx_range_event_type_ids(conn: sqlite3.Connection) -> list[int]:
    """IDs of the NVTX push/pop and start/end range event types, if present."""
    found = query_polars(
        conn,
        """
SELECT id FROM ENUM_NSYS_EVENT_TYPE
WHERE name IN (
'NvtxPushPopRange',
'NvtxStartEndRange',
'NvtxtPushPopRange',
'NvtxtStartEndRange'
)
""",
    )
    return [] if found.is_empty() else found["id"].cast(pl.Int64).to_list()
def domain_create_event_type_id(conn: sqlite3.Connection) -> int | None:
    """ID of the ``NvtxDomainCreate`` event type, or ``None`` if absent."""
    result = query_polars(
        conn,
        "SELECT id FROM ENUM_NSYS_EVENT_TYPE WHERE name = 'NvtxDomainCreate' LIMIT 1",
    )
    return None if result.is_empty() else int(result["id"][0])
def domain_ids_for_name(conn: sqlite3.Connection, domain_name: str) -> list[int]:
    """Distinct NVTX domain IDs registered under *domain_name* (may be empty)."""
    create_id = domain_create_event_type_id(conn)
    if create_id is None:
        return []
    result = query_polars(
        conn,
        """
SELECT DISTINCT n.domainId
FROM NVTX_EVENTS n
LEFT JOIN StringIds t ON n.textId = t.id
WHERE n.eventType = ?
AND COALESCE(t.value, n.text) = ?
""",
        (create_id, domain_name),
    )
    if result.is_empty():
        return []
    return result["domainId"].cast(pl.Int64).unique().to_list()
def build_nvtx_extract_sql(schema: NsysSchema) -> str:
    """SQL template for the NVTX extract; ``{et_placeholders}`` is filled in later.

    The payload SELECT expressions adapt to whichever payload columns the
    export actually has (per *schema*).
    """
    if schema.nvtx_uint64_col:
        payload_sel = "n.uint64Value AS payload_raw"
        ptype_sel = "CAST(NULL AS INTEGER) AS payload_type"
    elif schema.nvtx_payload_col:
        payload_sel = f"n.{schema.nvtx_payload_col} AS payload_raw"
        if schema.nvtx_payload_type_col:
            ptype_sel = f"n.{schema.nvtx_payload_type_col} AS payload_type"
        else:
            ptype_sel = "CAST(NULL AS INTEGER) AS payload_type"
    else:
        # No payload columns at all: select typed NULLs to keep a stable shape.
        payload_sel = "CAST(NULL AS BLOB) AS payload_raw"
        ptype_sel = "CAST(NULL AS INTEGER) AS payload_type"
    return f"""
SELECT n.start, n.end, n.globalTid, n.domainId, n.eventType,
COALESCE(t.value, n.text) AS label,
{payload_sel}, {ptype_sel}
FROM NVTX_EVENTS n
LEFT JOIN StringIds t ON n.textId = t.id
WHERE n.eventType IN ({{et_placeholders}})
AND n.end IS NOT NULL
"""
def load_enum_table(conn: sqlite3.Connection) -> pl.DataFrame:
    """Full ``ENUM_NSYS_EVENT_TYPE`` mapping (id -> name)."""
    sql = "SELECT id, name FROM ENUM_NSYS_EVENT_TYPE"
    return query_polars(conn, sql)
def load_nvtx_events_extract(
    conn: sqlite3.Connection,
    event_type_ids: list[int],
) -> pl.DataFrame:
    """Extract NVTX range rows (label + payload columns) for *event_type_ids*.

    Returns an empty, fully-typed frame when no event type IDs are given, so
    callers never have to special-case a missing table.
    """
    if not event_type_ids:
        return pl.DataFrame(
            schema={
                "start": pl.Int64,
                "end": pl.Int64,
                "globalTid": pl.Int64,
                "domainId": pl.Int64,
                "eventType": pl.Int64,
                "label": pl.Utf8,
                "payload_raw": pl.Binary,
                "payload_type": pl.Int64,
            }
        )
    # Payload column availability varies by nsys version; adapt the SELECT list.
    cols = sqlite_table_columns(conn, "NVTX_EVENTS")
    schema = detect_nvtx_payload_schema(cols)
    ph = ",".join("?" * len(event_type_ids))
    sql_template = build_nvtx_extract_sql(schema)
    sql = sql_template.replace("{et_placeholders}", ph)
    df = query_polars(conn, sql, tuple(event_type_ids))
    # Normalize dtypes
    for name in ("start", "end", "globalTid", "domainId", "eventType"):
        if name in df.columns:
            df = df.with_columns(pl.col(name).cast(pl.Int64))
    if "payload_type" in df.columns:
        df = df.with_columns(pl.col("payload_type").cast(pl.Int64))
    return df
def launch_name_ids(conn: sqlite3.Connection) -> list[int]:
    """StringIds of CUDA kernel-launch runtime API names."""
    found = query_polars(
        conn,
        """
SELECT id FROM StringIds
WHERE value LIKE 'cudaLaunchKernel%' OR value LIKE 'cudaLaunch%'
""",
    )
    return [] if found.is_empty() else found["id"].cast(pl.Int64).to_list()
def load_kernels_runtime(conn: sqlite3.Connection) -> pl.DataFrame:
    """Kernel intervals joined to runtime launch rows (for globalTid and launch time)."""
    lids = launch_name_ids(conn)
    if not lids:
        # No launch API names recorded: return an empty, fully-typed frame.
        return pl.DataFrame(
            schema={
                "k_start": pl.Int64,
                "k_end": pl.Int64,
                "launch_start": pl.Int64,
                "globalTid": pl.Int64,
            }
        )
    # Reports captured without GPU kernel tracing have no kernel table.
    if not sqlite_table_columns(conn, "CUPTI_ACTIVITY_KIND_KERNEL"):
        return pl.DataFrame(
            schema={
                "k_start": pl.Int64,
                "k_end": pl.Int64,
                "launch_start": pl.Int64,
                "globalTid": pl.Int64,
            }
        )
    ph = ",".join("?" * len(lids))
    sql = f"""
SELECT k.start AS k_start, k.end AS k_end, r.start AS launch_start, r.globalTid
FROM CUPTI_ACTIVITY_KIND_KERNEL k
INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME r
ON k.correlationId = r.correlationId
WHERE r.nameId IN ({ph})
AND r.globalTid IS NOT NULL
"""
    try:
        df = query_polars(conn, sql, tuple(lids))
    except Exception:
        # NOTE(review): deliberately broad — any query failure (e.g. a missing
        # CUPTI_ACTIVITY_KIND_RUNTIME table) is treated as "no kernel data".
        return pl.DataFrame(
            schema={
                "k_start": pl.Int64,
                "k_end": pl.Int64,
                "launch_start": pl.Int64,
                "globalTid": pl.Int64,
            }
        )
    if df.is_empty():
        return df
    for name in ("k_start", "k_end", "launch_start", "globalTid"):
        df = df.with_columns(pl.col(name).cast(pl.Int64))
    return df
def cache_key_segment(query: int | None, iteration: int | None) -> str:
    """Cache-path segment ``"<query>-<iteration>"``; ``None`` renders as ``"all"``."""
    q_part = "all" if query is None else str(query)
    i_part = "all" if iteration is None else str(iteration)
    return f"{q_part}-{i_part}"
def parquet_cache_paths(
    nsys_rep: Path,
    segment: str,
) -> dict[str, Path]:
    """Per-table parquet cache paths under ``<report stem>.parquet/``."""
    cache_dir = nsys_rep.parent / f"{nsys_rep.stem}.parquet"
    return {name: cache_dir / f"{segment}-{name}.parquet" for name in CACHE_TABLES}
def sqlite_mtime(path: Path) -> float:
    """Modification time of *path* in seconds since the epoch."""
    info = path.stat()
    return info.st_mtime
def cache_is_fresh(sqlite_path: Path, paths: dict[str, Path]) -> bool:
    """True when every cache file exists and is at least as new as the SQLite export."""
    if not sqlite_path.is_file():
        return False
    source_mtime = sqlite_path.stat().st_mtime
    return all(
        p.is_file() and p.stat().st_mtime >= source_mtime
        for p in paths.values()
    )
def load_extracted_tables(
    conn: sqlite3.Connection,
    event_type_ids: list[int],
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Load the three extract tables: NVTX events, event-type enum, kernels/runtime."""
    return (
        load_nvtx_events_extract(conn, event_type_ids),
        load_enum_table(conn),
        load_kernels_runtime(conn),
    )
def maybe_read_parquet_cache(
    paths: dict[str, Path],
    sqlite_path: Path,
    *,
    use_cache: bool,
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame] | None:
    """Read the parquet extract cache, or ``None`` when disabled or stale."""
    if not (use_cache and cache_is_fresh(sqlite_path, paths)):
        return None
    nvtx = pl.read_parquet(paths["nvtx_events"])
    enum_t = pl.read_parquet(paths["enum_nsys_event_type"])
    kr = pl.read_parquet(paths["kernels_runtime"])
    return nvtx, enum_t, kr
def write_parquet_cache(
    paths: dict[str, Path],
    nvtx: pl.DataFrame,
    enum_t: pl.DataFrame,
    kr: pl.DataFrame,
) -> None:
    """Write the three extract tables to their parquet cache paths."""
    paths["nvtx_events"].parent.mkdir(parents=True, exist_ok=True)
    for key, frame in (
        ("nvtx_events", nvtx),
        ("enum_nsys_event_type", enum_t),
        ("kernels_runtime", kr),
    ):
        frame.write_parquet(paths[key])
# -----------------------------------------------------------------------------
# Query / iteration windows
# -----------------------------------------------------------------------------
def load_query_iteration_windows(
    nvtx: pl.DataFrame,
    domain_ids: list[int],
    event_type_ids: list[int],
) -> pl.DataFrame:
    """Rows for ``Query N - Iteration M`` NVTX ranges: start, end, query_num, iteration_num."""
    empty = pl.DataFrame(
        schema={
            "start": pl.Int64,
            "end": pl.Int64,
            "query_num": pl.Int64,
            "iteration_num": pl.Int64,
        }
    )
    if not domain_ids or not event_type_ids or nvtx.is_empty():
        return empty
    stripped = pl.col("label").cast(pl.Utf8).str.strip_chars()
    return (
        nvtx.filter(
            pl.col("domainId").is_in(domain_ids)
            & pl.col("eventType").is_in(event_type_ids)
            & pl.col("label").is_not_null()
        )
        # Pull N and M out of the window label via the regex capture groups;
        # non-matching labels become null and are dropped below.
        .with_columns(
            stripped.str.extract(QUERY_ITER_REGEX, 1)
            .cast(pl.Int64, strict=False)
            .alias("query_num"),
            stripped.str.extract(QUERY_ITER_REGEX, 2)
            .cast(pl.Int64, strict=False)
            .alias("iteration_num"),
        )
        .filter(pl.col("query_num").is_not_null() & pl.col("iteration_num").is_not_null())
        .select(
            pl.col("start").cast(pl.Int64),
            pl.col("end").cast(pl.Int64),
            pl.col("query_num"),
            pl.col("iteration_num"),
        )
    )
def filter_windows_for_cli(
    qi: pl.DataFrame,
    query_id: int | None,
    iteration: int | None,
) -> tuple[pl.DataFrame | None, pl.DataFrame]:
    """
    Split ``qi`` into ``(windows, wall_slices)``.

    ``wall_slices`` always holds the ``start``/``end`` pairs of every
    ``Query N - Iteration M`` row. When ``query_id`` or ``iteration`` is
    given, ``windows`` is the matching subset of those pairs (possibly
    empty); otherwise ``windows`` is ``None``, meaning "no filtering".
    """
    wall = qi.select(pl.col("start"), pl.col("end"))
    if query_id is None and iteration is None:
        return None, wall
    cond = pl.lit(True)
    if query_id is not None:
        cond = cond & (pl.col("query_num") == query_id)
    if iteration is not None:
        cond = cond & (pl.col("iteration_num") == iteration)
    filtered = qi.filter(cond).select(pl.col("start"), pl.col("end"))
    return filtered, filtered
def filter_qi_dataframe(
    qi: pl.DataFrame,
    query_id: int | None,
    iteration: int | None,
) -> pl.DataFrame:
    """Filter ``Query N - Iteration M`` rows by optional ``query_id`` / ``iteration``."""
    conditions = []
    if query_id is not None:
        conditions.append(pl.col("query_num") == query_id)
    if iteration is not None:
        conditions.append(pl.col("iteration_num") == iteration)
    if not conditions:
        # Nothing requested: pass the frame through untouched.
        return qi
    predicate = conditions[0]
    for extra in conditions[1:]:
        predicate = predicate & extra
    return qi.filter(predicate)
def assign_query_iteration_to_events(df: pl.DataFrame, qi: pl.DataFrame) -> pl.DataFrame:
    """
    Add ``query`` and ``iteration`` columns by matching each event to the
    ``Query N - Iteration M`` window with the largest time overlap (tie: lower query, iteration).
    """
    if qi.is_empty():
        # No windows: every event gets null query/iteration.
        return df.with_columns(
            pl.lit(None).cast(pl.Int64).alias("query"),
            pl.lit(None).cast(pl.Int64).alias("iteration"),
        )
    w = qi.select(
        pl.col("start").alias("qs"),
        pl.col("end").alias("qe"),
        pl.col("query_num").alias("query"),
        pl.col("iteration_num").alias("iteration"),
    )
    # Cross-join events to windows, keep overlapping pairs, and compute each
    # pair's overlap length in ``_ov``.
    j = (
        df.with_row_index("__rid")
        .join(w, how="cross")
        .filter((pl.col("start") < pl.col("qe")) & (pl.col("end") > pl.col("qs")))
        .with_columns(
            (
                pl.min_horizontal(pl.col("end"), pl.col("qe"))
                - pl.max_horizontal(pl.col("start"), pl.col("qs"))
            ).alias("_ov")
        )
    )
    if j.is_empty():
        return df.with_columns(
            pl.lit(None).cast(pl.Int64).alias("query"),
            pl.lit(None).cast(pl.Int64).alias("iteration"),
        )
    # Largest overlap wins; ties resolved by lower query then lower iteration
    # (sort order plus keep="first").
    picked = (
        j.sort(["__rid", "_ov", "query", "iteration"], descending=[False, True, False, False])
        .unique(subset=["__rid"], keep="first")
        .select("__rid", "query", "iteration")
    )
    return df.with_row_index("__rid").join(picked, on="__rid", how="left").drop("__rid")
def nvtx_overlaps_windows(
    start: int,
    end: int,
    windows: pl.DataFrame | None,
) -> bool:
    """True when ``[start, end)`` overlaps any window row (or no filter is active)."""
    if windows is None:
        # No window filter configured: everything matches.
        return True
    if windows.is_empty():
        return False
    overlapping = windows.filter((pl.col("start") < end) & (pl.col("end") > start))
    return overlapping.height > 0
def filter_intervals_overlap_windows(
    df: pl.DataFrame,
    windows: pl.DataFrame | None,
) -> pl.DataFrame:
    """
    Keep rows whose ``[start, end)`` overlaps at least one row in ``windows``
    (same semantics as :func:`nvtx_overlaps_windows`, vectorized).
    """
    if windows is None:
        return df
    if df.is_empty() or windows.is_empty():
        return df.head(0)
    bounds = windows.select(
        pl.col("start").cast(pl.Int64).alias("_ws"),
        pl.col("end").cast(pl.Int64).alias("_we"),
    )
    indexed = df.with_row_index("_eid")
    # Row IDs with at least one overlapping window.
    overlapping_ids = (
        indexed.join(bounds, how="cross")
        .filter((pl.col("start") < pl.col("_we")) & (pl.col("end") > pl.col("_ws")))
        .select("_eid")
        .unique()
    )
    return indexed.join(overlapping_ids, on="_eid", how="inner").drop("_eid")
# -----------------------------------------------------------------------------
# Analytics: operator summary
# -----------------------------------------------------------------------------
def _empty_device_by_operator() -> pl.DataFrame:
    """Empty result frame for per-operator device time."""
    schema = {"operator": pl.Utf8, "device_time_ns": pl.Int64}
    return pl.DataFrame(schema=schema)
def attribute_kernel_ms_innermost(
    nvtx_ranges: pl.DataFrame,
    kernels: pl.DataFrame,
    *,
    label_col: str = "label",
) -> pl.DataFrame:
    """
    For each kernel row, attribute full kernel duration (k_end - k_start) to the
    NVTX range label on the same globalTid that is innermost at launch: among
    ranges with start <= launch_start < end, pick the range with maximum start.

    ``nvtx_ranges`` must include columns ``start``, ``end``, ``globalTid``, and
    ``label_col`` (default ``label``). ``kernels`` must include ``k_start``,
    ``k_end``, ``launch_start``, ``globalTid``.

    Returns a DataFrame with columns ``operator`` and ``device_time_ns`` (one row
    per distinct NVTX label that received kernel time).
    """
    if nvtx_ranges.is_empty() or kernels.is_empty():
        return _empty_device_by_operator()
    nvtx = (
        nvtx_ranges.filter(pl.col("end") > pl.col("start"))
        .select(
            pl.col("start").alias("n_start"),
            pl.col("end").alias("n_end"),
            pl.col("globalTid"),
            pl.col(label_col).alias("nvtx_label"),
        )
    )
    if nvtx.is_empty():
        return _empty_device_by_operator()
    k = kernels.filter(pl.col("k_end") > pl.col("k_start")).with_row_index("__kid")
    # Candidates: ranges on the same thread that were active at the launch time.
    cand = k.join(nvtx, on="globalTid", how="inner").filter(
        (pl.col("n_start") <= pl.col("launch_start"))
        & (pl.col("launch_start") < pl.col("n_end"))
    )
    if cand.is_empty():
        return _empty_device_by_operator()
    # Per kernel, keep the NVTX candidate with largest ``start`` (innermost).
    picked = (
        cand.sort(["__kid", "n_start"], descending=[False, True])
        .unique(subset=["__kid"], keep="first")
        .with_columns((pl.col("k_end") - pl.col("k_start")).alias("dur_ns"))
    )
    return (
        picked.group_by("nvtx_label")
        .agg(pl.col("dur_ns").sum().alias("device_time_ns"))
        .rename({"nvtx_label": "operator"})
    )
# Column order of one operator-summary row (without the query/iteration keys).
SUMMARY_ROW_COLS = [
    "operator",
    "count",
    "host_time_ns",
    "host_pct",
    "wall_time_ns",
    "wall_pct",
    "device_time_ns",
    "device_pct",
]
def _operator_summary_single_partition(
    df: pl.DataFrame,
    kernels: pl.DataFrame,
) -> pl.DataFrame:
    """
    Operator summary for one partition of IR NVTX rows (already scoped). No query/iteration columns.
    """
    if df.is_empty():
        # Degenerate partition: a lone zeroed "Total" row keeps output shape stable.
        return pl.DataFrame(
            {
                "operator": ["Total"],
                "count": [0],
                "host_time_ns": [0],
                "host_pct": [0.0],
                "wall_time_ns": [0],
                "wall_pct": [0.0],
                "device_time_ns": [0],
                "device_pct": [0.0],
            }
        )
    # Calendar span of this partition: earliest start to latest end.
    wall_total_ns = int(df["end"].max() - df["start"].min())
    wall_by_op = wall_union_ns_per_operator(df)
    # Host time: plain sum of range lengths per operator (overlaps not merged).
    host_agg = (
        df.group_by("label_clean")
        .agg(
            (pl.col("end") - pl.col("start")).sum().alias("host_time_ns"),
            pl.len().alias("count"),
        )
        .rename({"label_clean": "operator"})
    )
    nvtx_for_device = df.select(
        pl.col("start"),
        pl.col("end"),
        pl.col("globalTid"),
        pl.col("label_clean").alias("label"),
    )
    dev_pl = attribute_kernel_ms_innermost(nvtx_for_device, kernels)
    # Operators may appear in host data only, device data only, or both.
    ops = pl.concat([host_agg.select("operator"), dev_pl.select("operator")]).unique()
    merged = (
        ops.join(host_agg, on="operator", how="left")
        .join(dev_pl, on="operator", how="left")
        .join(wall_by_op, on="operator", how="left")
        .with_columns(
            pl.col("host_time_ns").fill_null(0),
            pl.col("device_time_ns").fill_null(0),
            pl.col("count").fill_null(0),
            pl.col("wall_union_ns").fill_null(0),
        )
    )
    total_h = int(merged["host_time_ns"].sum())
    total_d = int(merged["device_time_ns"].sum())
    total_count = int(merged["count"].sum())
    # Percent columns; each denominator is guarded against zero.
    merged = merged.sort("operator").with_columns(
        pl.when(pl.lit(total_h) > 0)
        .then(pl.col("host_time_ns") / pl.lit(total_h) * 100.0)
        .otherwise(0.0)
        .alias("host_pct"),
        pl.when(pl.lit(wall_total_ns) > 0)
        .then(pl.col("wall_union_ns") / pl.lit(wall_total_ns) * 100.0)
        .otherwise(0.0)
        .alias("wall_pct"),
        pl.when(pl.lit(total_d) > 0)
        .then(pl.col("device_time_ns") / pl.lit(total_d) * 100.0)
        .otherwise(0.0)
        .alias("device_pct"),
    )
    merged = (
        merged.select(
            pl.col("operator"),
            pl.col("count"),
            pl.col("host_time_ns"),
            pl.col("host_pct"),
            pl.col("wall_union_ns").alias("wall_time_ns"),
            pl.col("wall_pct"),
            pl.col("device_time_ns"),
            pl.col("device_pct"),
        )
        .with_columns(pl.col("count").cast(pl.Int64))
    )
    total_row = pl.DataFrame(
        {
            "operator": ["Total"],
            "count": [total_count],
            "host_time_ns": [total_h],
            "host_pct": [100.0 if total_h else 0.0],
            "wall_time_ns": [wall_total_ns],
            "wall_pct": [100.0 if wall_total_ns else 0.0],
            "device_time_ns": [total_d],
            "device_pct": [100.0 if total_d else 0.0],
        }
    ).select(SUMMARY_ROW_COLS)
    return pl.concat([merged, total_row])
def compute_operator_summary(
    nvtx: pl.DataFrame,
    kernels: pl.DataFrame,
    domain_ids: list[int],
    event_type_ids: list[int],
    qi: pl.DataFrame,
    query_filter: int | None,
    iteration_filter: int | None,
    *,
    all_types: bool,
    exclude: frozenset[str],
) -> pl.DataFrame:
    """Operator summary grouped by ``query`` and ``iteration`` from ``Query N - Iteration M`` NVTX."""
    df = nvtx.filter(
        pl.col("domainId").is_in(domain_ids)
        & pl.col("eventType").is_in(event_type_ids)
        & pl.col("label").is_not_null()
        & (pl.col("end") > pl.col("start"))
    ).with_columns(
        pl.col("label").cast(pl.Utf8).str.strip_chars().alias("label_clean"),
    )
    df = df.filter(pl.col("label_clean").str.len_chars() > 0)
    if not all_types:
        df = df.filter(~pl.col("label_clean").is_in(list(exclude)))
    # The window labels themselves are not operators; drop them.
    df = df.filter(~pl.col("label_clean").str.contains(QUERY_ITER_REGEX))
    qi_use = filter_qi_dataframe(qi, query_filter, iteration_filter)
    summary_out_cols = ["query", "iteration", *SUMMARY_ROW_COLS]
    if qi.is_empty():
        # No windows at all: one unpartitioned summary with null query/iteration.
        part = _operator_summary_single_partition(df, kernels)
        return part.with_columns(
            pl.lit(None).cast(pl.Int64).alias("query"),
            pl.lit(None).cast(pl.Int64).alias("iteration"),
        ).select(summary_out_cols)
    df_qi = assign_query_iteration_to_events(df, qi_use)
    df_qi = df_qi.filter(pl.col("query").is_not_null() & pl.col("iteration").is_not_null())
    if df_qi.is_empty():
        return pl.DataFrame(
            schema={
                "query": pl.Int64,
                "iteration": pl.Int64,
                "operator": pl.Utf8,
                "count": pl.Int64,
                "host_time_ns": pl.Int64,
                "host_pct": pl.Float64,
                "wall_time_ns": pl.Int64,
                "wall_pct": pl.Float64,
                "device_time_ns": pl.Int64,
                "device_pct": pl.Float64,
            }
        )

    def _add_qi_cols(g: pl.DataFrame) -> pl.DataFrame:
        # Summarize one (query, iteration) group and re-attach its key columns.
        qv = int(g["query"][0])
        iv = int(g["iteration"][0])
        sub = g.drop("query", "iteration")
        block = _operator_summary_single_partition(sub, kernels)
        return block.with_columns(
            pl.lit(qv).alias("query"),
            pl.lit(iv).alias("iteration"),
        ).select(summary_out_cols)

    out = df_qi.group_by(["query", "iteration"], maintain_order=False).map_groups(
        _add_qi_cols
    )
    # Keep ``Total`` last within each (query, iteration); sorting only by ``operator``
    # would place it alphabetically (e.g. before ``Union``).
    return (
        out.with_columns((pl.col("operator") == "Total").cast(pl.Int8).alias("_total_last"))
        .sort(["query", "iteration", "_total_last", "operator"])
        .drop("_total_last")
    )
# -----------------------------------------------------------------------------
# Analytics: I/O (kvikio bytes + Scan)
# -----------------------------------------------------------------------------
def row_payload_bytes(
    row: dict[str, Any],
    schema: NsysSchema,
    nvtx_cols: frozenset[str],  # currently unused; kept for call-site compatibility
) -> int | None:
    """Decoded byte count for one NVTX row dict, honoring the export *schema*."""
    if schema.nvtx_uint64_col and "payload_raw" in row:
        raw = row["payload_raw"]
        if raw is not None:
            # Direct uint64 column: mask into the unsigned 64-bit range.
            return int(raw) & 0xFFFFFFFFFFFFFFFF
    # Fall back to decoding the blob/typed payload pair.
    raw = row.get("payload_raw")
    pt = row.get("payload_type")
    pti = int(pt) if pt is not None else None
    return decode_nvtx_uint64(raw, pti)
def add_payload_bytes_column(df: pl.DataFrame, schema: NsysSchema) -> pl.DataFrame:
    """Add ``payload_bytes`` (decoded byte count) per NVTX row; empty ``df`` unchanged shape."""
    if df.is_empty():
        return df.with_columns(pl.lit(None).cast(pl.Int64).alias("payload_bytes"))
    if schema.nvtx_uint64_col:
        # SQLite / Polars may load uint64Value as Float64; mask in UInt64 space, then Int64.
        pr = pl.col("payload_raw")
        mask = pl.lit(0xFFFFFFFFFFFFFFFF, dtype=pl.UInt64)
        return df.with_columns(
            pl.when(pr.is_not_null())
            .then((pr.cast(pl.UInt64, strict=False) & mask).cast(pl.Int64))
            .otherwise(pl.lit(None, dtype=pl.Int64))
            .alias("payload_bytes")
        )
    # Blob/typed payload path: decode each row in Python.
    return df.with_columns(
        pl.struct(["payload_raw", "payload_type"])
        .map_elements(
            lambda r: decode_nvtx_uint64(r["payload_raw"], r["payload_type"]),
            return_dtype=pl.Int64,
        )
        .alias("payload_bytes")
    )
def compute_io_summary(
    nvtx: pl.DataFrame,
    kernels: pl.DataFrame,
    kvikio_domain_ids: list[int],
    cudf_domain_ids: list[int],
    event_type_ids: list[int],
    windows: pl.DataFrame | None,
    nvtx_sqlite_cols: NsysSchema,
) -> dict[str, Any]:
    """I/O summary dict: kvikio read bytes plus Scan counts, durations, throughput."""
    kset = set(kvikio_domain_ids)
    cset = set(cudf_domain_ids)
    et_list = list(event_type_ids)
    base = nvtx.filter(
        pl.col("eventType").is_in(et_list) & (pl.col("end") > pl.col("start"))
    )
    base = filter_intervals_overlap_windows(base, windows)
    label_txt = (
        pl.col("label").cast(pl.Utf8).fill_null("").str.strip_chars().alias("_label_txt")
    )
    base = base.with_columns(label_txt)
    # kvikio remote reads carry the byte count in the NVTX payload.
    kv = base.filter(
        pl.col("domainId").is_in(list(kset))
        & pl.col("_label_txt").str.contains(KVIKIO_READ_NAME_SUBSTR, literal=True)
    )
    kv = add_payload_bytes_column(kv, nvtx_sqlite_cols)
    total_bytes = (
        0
        if kv.is_empty()
        else int(kv.select(pl.col("payload_bytes").sum()).item() or 0)
    )
    # Scan operator ranges from the cudf domain.
    scan_df = base.filter(
        pl.col("domainId").is_in(list(cset)) & (pl.col("_label_txt") == "Scan")
    ).select(
        pl.col("start").cast(pl.Int64),
        pl.col("end").cast(pl.Int64),
        pl.col("globalTid").cast(pl.Int64),
        pl.col("_label_txt").alias("label"),
    )
    scan_count = scan_df.height
    if scan_df.is_empty():
        sum_scan_host_ns = 0
        scan_union_wall_ns = 0
    else:
        sum_scan_host_ns = int(
            scan_df.select((pl.col("end") - pl.col("start")).sum()).item()
        )
        scan_union_wall_ns = merge_intervals_union_ns(scan_df.select("start", "end"))
    dev_by_op = attribute_kernel_ms_innermost(scan_df, kernels)
    scan_row = dev_by_op.filter(pl.col("operator") == "Scan")
    dev_scan_ns = int(scan_row["device_time_ns"][0]) if scan_row.height > 0 else 0

    # Throughput: bytes / denominator (same total_bytes for all)
    def tp(denom_ns: int) -> float | None:
        if denom_ns <= 0:
            return None
        return float(total_bytes) / float(denom_ns)

    return {
        "scan_count": scan_count,
        "total_bytes": total_bytes,
        "sum_scan_host_ns": sum_scan_host_ns,
        "scan_union_wall_ns": scan_union_wall_ns,
        "sum_scan_device_ns": dev_scan_ns,
        "throughput_bytes_per_ns_sum_host": tp(sum_scan_host_ns),
        "throughput_bytes_per_ns_union_wall": tp(scan_union_wall_ns),
        "throughput_bytes_per_ns_sum_device": tp(dev_scan_ns),
    }
# -----------------------------------------------------------------------------
# Human-readable formatting (table output only)
# -----------------------------------------------------------------------------
def format_ns(ns: float) -> str:
    """Render nanoseconds with an auto-selected unit (s / ms / µs / ns)."""
    for divisor, unit in ((1e9, "s"), (1e6, "ms"), (1e3, "µs")):
        if ns >= divisor:
            return f"{ns / divisor:.3f} {unit}"
    return f"{ns:.0f} ns"
def format_bytes(n: int) -> str:
    """Render a byte count in GiB / MiB / KiB, or the raw number below 1 KiB."""
    for shift, unit in ((30, "GiB"), (20, "MiB"), (10, "KiB")):
        if n >= 1 << shift:
            return f"{n / (1 << shift):.3f} {unit}"
    return str(n)
def format_throughput_table(bytes_count: int, denom_ns: int) -> str:
    """Render ``bytes_count / denom_ns`` as GiB/s (MiB/s below 1 GiB/s); ``"n/a"`` for non-positive denominators."""
    if denom_ns <= 0:
        return "n/a"
    bytes_per_second = bytes_count / denom_ns * 1e9
    gib_rate = bytes_per_second / (1024**3)
    if gib_rate >= 1.0:
        return f"{gib_rate:.4f} GiB/s"
    return f"{bytes_per_second / (1024**2):.4f} MiB/s"
# -----------------------------------------------------------------------------
# Render
# -----------------------------------------------------------------------------
def _fmt_qi_cell(v: object) -> str:
if v is None:
return "—"
return str(int(v))
def render_summary_table(console: Console, df: pl.DataFrame) -> None:
    """Print the operator summary as a rich table ("Total" rows rendered bold)."""
    t = Table(show_header=True, header_style="bold")
    t.add_column("Query")
    t.add_column("Iter")
    t.add_column("Operator")
    t.add_column("Count", justify="right")
    t.add_column("Host Time", justify="right")
    t.add_column("Host %", justify="right")
    t.add_column("Wall Time", justify="right")
    t.add_column("Wall %", justify="right")
    t.add_column("Device Time", justify="right")
    t.add_column("Device %", justify="right")
    for row in df.iter_rows(named=True):
        op = row["operator"]
        sty = "bold" if op == "Total" else None
        t.add_row(
            _fmt_qi_cell(row.get("query")),
            _fmt_qi_cell(row.get("iteration")),
            op,
            str(int(row["count"])),
            format_ns(float(row["host_time_ns"])),
            f"{row['host_pct']:.1f}%",
            format_ns(float(row["wall_time_ns"])),
            f"{row['wall_pct']:.1f}%",
            format_ns(float(row["device_time_ns"])),
            f"{row['device_pct']:.1f}%",
            style=sty,
        )
    console.print(t)
def render_io_table(console: Console, d: dict[str, Any]) -> None:
    """Print the I/O summary dict (from ``compute_io_summary``) as text sections."""
    console.print("[bold]Scan[/bold]")
    console.print(f" Count: {d['scan_count']}")
    console.print(f" Bytes: {d['total_bytes']} ({format_bytes(int(d['total_bytes']))})")
    console.print()
    console.print("[bold]Durations[/bold]")
    console.print(
        f" Sum host durations: {format_ns(float(d['sum_scan_host_ns']))}"
    )
    console.print(
        f" Union host durations (wall time): {format_ns(float(d['scan_union_wall_ns']))}"
    )
    # Device-time line intentionally disabled in output for now.
    # console.print(
    #     f" Sum of kernel time attributed to Scan: {format_ns(float(d['sum_scan_device_ns']))}"
    # )
    console.print()
    console.print("[bold]Throughput (total_bytes / denominator)[/bold]")
    tb = int(d["total_bytes"])
    for label, key in (
        ("Host time", "sum_scan_host_ns"),
        ("Wall time", "scan_union_wall_ns"),
        # ("Σ Scan device (kernels)", "sum_scan_device_ns"),
    ):
        den = int(d[key])
        tp = (tb / den) if den > 0 else None
        tp_s = format_throughput_table(tb, den) if tp is not None else "n/a"
        console.print(f" {label}: {tp_s}")
def df_to_json_records(df: pl.DataFrame) -> str:
    """Serialize *df* as an indented JSON array of row records."""
    records = df.to_dicts()
    return json.dumps(records, indent=2)
def print_csv(df: pl.DataFrame) -> None:
    """Write *df* to stdout as CSV (with header)."""
    csv_text = df.write_csv()
    print(csv_text)
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
def add_export_args(p: argparse.ArgumentParser) -> None:
    """Register the CLI arguments shared by all subcommands (report/export/filters)."""
    p.add_argument(
        "report",
        type=Path,
        help="Path to .nsys-rep",
    )
    p.add_argument(
        "--sqlite",
        type=Path,
        default=None,
        help="SQLite path (default: <report>.sqlite)",
    )
    p.add_argument(
        "--no-export",
        action="store_true",
        help="Do not run nsys export; require existing SQLite",
    )
    p.add_argument(
        "--force-export",
        action="store_true",
        help="Re-run nsys export even if SQLite exists",
    )
    p.add_argument(
        "--domain",
        default="cudf_polars",
        help="NVTX domain for IR / Query ranges (default: cudf_polars)",
    )
    p.add_argument(
        "--kvikio-domain",
        default="libkvikio",
        help="NVTX domain for kvikio reads (default: libkvikio)",
    )
    p.add_argument(
        "--query",
        type=int,
        default=None,
        help=(
            "Restrict summary to this query index N (only these Query N - Iteration M "
            "windows are used to label events; default: all queries)"
        ),
    )
    p.add_argument(
        "--iteration",
        type=int,
        default=None,
        help=(
            "Restrict summary to this iteration M (only these Query N - Iteration M "
            "windows are used to label events; default: all iterations)"
        ),
    )
    p.add_argument(
        "--format",
        choices=("table", "json", "csv"),
        default="table",
        help="Output format (table uses human units; json/csv use raw ns and bytes)",
    )
    # Paired on/off flags sharing one dest, defaulting to on.
    p.add_argument(
        "--parquet-cache",
        dest="parquet_cache",
        action="store_true",
        default=True,
        help="Write parquet extract cache (default: on)",
    )
    p.add_argument(
        "--no-parquet-cache",
        dest="parquet_cache",
        action="store_false",
        help="Do not write parquet cache",
    )
    p.add_argument(
        "--use-cache",
        dest="use_cache",
        action="store_true",
        default=True,
        help="Load from parquet cache when fresh (default: on)",
    )
    p.add_argument(
        "--no-use-cache",
        dest="use_cache",
        action="store_false",
        help="Always read SQLite",
    )
# Epilog for ``summary --help`` explaining how the three time columns differ.
SUMMARY_HELP_EPILOG = """
Time columns (summary):
Host Sum of NVTX range lengths (end - start) for each range tagged with that
operator on CPU threads. Parent and child ranges can overlap in time, so
sums can exceed wall-clock span. The Host percent column is the share of
that summed time within the Query/Iteration group.
Wall Union of those intervals on one global timeline (overlaps merged). The Wall
percent column is union length divided by the span from the earliest IR
start to the latest IR end in the same group (calendar coverage).
Device Sum of GPU kernel durations attributed to this operator: each kernel is
assigned to the innermost active NVTX range on the launching CPU thread at
cuda launch. The Device percent column is the share of summed kernel time
in the group.
"""
def cmd_summary(args: argparse.Namespace) -> int:
    """Print a per-operator summary grouped by Query / Iteration.

    Re-parses the forwarded CLI tail (``args._remainder``) with a dedicated
    ArgumentParser, ensures a SQLite export of the .nsys-rep exists, loads
    the NVTX / kernel tables (optionally through the parquet cache), and
    renders the operator summary in the requested format.

    Parameters
    ----------
    args
        Namespace carrying ``_remainder``: the raw argv tail after the
        ``summary`` subcommand name.

    Returns
    -------
    int
        Process exit code: 0 on success, non-zero on any error.
    """
    p = argparse.ArgumentParser(
        description=(
            "Per-operator table grouped by Query / Iteration from NVTX. "
            "Shows host (CPU NVTX), wall (union) coverage, and device (GPU) time."
        ),
        epilog=SUMMARY_HELP_EPILOG,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument(
        "--all-types",
        action="store_true",
        help="Include every NVTX label in the domain (not only known IR names)",
    )
    p.add_argument(
        "--exclude",
        default=",".join(sorted(DEFAULT_EXCLUDE_LABELS)),
        help="Comma-separated labels to skip (default: ConvertIR,ExecuteIR)",
    )
    # Shared report/sqlite/export/cache/format options (defined elsewhere
    # in this file).
    add_export_args(p)
    ns = p.parse_args(args._remainder)
    nsys_rep = ns.report.expanduser().resolve()
    if not nsys_rep.is_file():
        print(f"Report not found: {nsys_rep}", file=sys.stderr)
        return 1
    sqlite_path = (ns.sqlite or default_sqlite_path(nsys_rep)).expanduser().resolve()
    if not ns.no_export:
        ensure_nsys_available()
        try:
            run_nsys_export(nsys_rep, sqlite_path, force=ns.force_export)
        except subprocess.CalledProcessError as e:
            # e.returncode can be 0/None in odd cases; still exit non-zero.
            print(f"nsys export failed (exit {e.returncode})", file=sys.stderr)
            return e.returncode or 1
    elif not sqlite_path.is_file():
        print(f"SQLite not found (--no-export): {sqlite_path}", file=sys.stderr)
        return 1
    # Cache paths are keyed by the query/iteration selection.
    segment = cache_key_segment(ns.query, ns.iteration)
    ppaths = parquet_cache_paths(nsys_rep, segment)
    # Read-only URI connection: analysis never mutates the nsys export.
    conn = sqlite3.connect(f"file:{sqlite_path}?mode=ro", uri=True)
    try:
        et_ids = nvtx_range_event_type_ids(conn)
        if not et_ids:
            print("No NVTX range event types in export.", file=sys.stderr)
            return 1
        cached = maybe_read_parquet_cache(ppaths, sqlite_path, use_cache=ns.use_cache)
        if cached is not None:
            nvtx, enum_t, kr = cached
        else:
            nvtx, enum_t, kr = load_extracted_tables(conn, et_ids)
            if ns.parquet_cache:
                write_parquet_cache(ppaths, nvtx, enum_t, kr)
        domain_ids = domain_ids_for_name(conn, ns.domain)
        if not domain_ids:
            print(
                f"No NvtxDomainCreate for domain {ns.domain!r}.",
                file=sys.stderr,
            )
            return 1
        qi = load_query_iteration_windows(nvtx, domain_ids, et_ids)
        # qi_filtered is only used to fail fast on an empty selection;
        # compute_operator_summary receives the UNfiltered qi plus the raw
        # filters and presumably applies them itself -- confirm against the
        # helper's implementation.
        qi_filtered = filter_qi_dataframe(qi, ns.query, ns.iteration)
        if (ns.query is not None or ns.iteration is not None) and qi_filtered.is_empty():
            print(
                "No NVTX ranges matched --query/--iteration.",
                file=sys.stderr,
            )
            return 1
        # Drop empty fragments so "--exclude a,,b" behaves like "a,b".
        exclude = frozenset(
            x.strip() for x in ns.exclude.split(",") if x.strip()
        )
        df = compute_operator_summary(
            nvtx,
            kr,
            domain_ids,
            et_ids,
            qi,
            query_filter=ns.query,
            iteration_filter=ns.iteration,
            all_types=ns.all_types,
            exclude=exclude,
        )
    finally:
        conn.close()
    console = Console()
    if ns.format == "table":
        render_summary_table(console, df)
    elif ns.format == "json":
        print(df_to_json_records(df))
    else:
        print_csv(df)
    return 0
def cmd_io(args: argparse.Namespace) -> int:
    """Print an I/O summary (Scan + kvikio reads) for an nsys report.

    Re-parses the forwarded CLI tail (``args._remainder``), ensures a SQLite
    export of the .nsys-rep exists, loads the NVTX / kernel tables
    (optionally through the parquet cache), and renders the I/O summary as a
    table, JSON, or CSV.

    Parameters
    ----------
    args
        Namespace carrying ``_remainder``: the raw argv tail after the
        ``io`` subcommand name.

    Returns
    -------
    int
        Process exit code: 0 on success, non-zero on any error.
    """
    # Description added for parity with cmd_summary's parser; the shared
    # report/sqlite/export/cache/format options come from add_export_args.
    p = argparse.ArgumentParser(
        description="I/O summary (Scan + kvikio reads) from NVTX and kernel data.",
    )
    add_export_args(p)
    ns = p.parse_args(args._remainder)
    nsys_rep = ns.report.expanduser().resolve()
    if not nsys_rep.is_file():
        print(f"Report not found: {nsys_rep}", file=sys.stderr)
        return 1
    sqlite_path = (ns.sqlite or default_sqlite_path(nsys_rep)).expanduser().resolve()
    if not ns.no_export:
        ensure_nsys_available()
        try:
            run_nsys_export(nsys_rep, sqlite_path, force=ns.force_export)
        except subprocess.CalledProcessError as e:
            print(f"nsys export failed (exit {e.returncode})", file=sys.stderr)
            return e.returncode or 1
    elif not sqlite_path.is_file():
        print(f"SQLite not found (--no-export): {sqlite_path}", file=sys.stderr)
        return 1
    segment = cache_key_segment(ns.query, ns.iteration)
    ppaths = parquet_cache_paths(nsys_rep, segment)
    # Read-only URI connection: we never mutate the export.
    conn = sqlite3.connect(f"file:{sqlite_path}?mode=ro", uri=True)
    try:
        cols = sqlite_table_columns(conn, "NVTX_EVENTS")
        nvtx_schema = detect_nvtx_payload_schema(cols)
        et_ids = nvtx_range_event_type_ids(conn)
        if not et_ids:
            print("No NVTX range event types in export.", file=sys.stderr)
            return 1
        cached = maybe_read_parquet_cache(ppaths, sqlite_path, use_cache=ns.use_cache)
        if cached is not None:
            nvtx, _enum_t, kr = cached
        else:
            nvtx, _enum_t, kr = load_extracted_tables(conn, et_ids)
            if ns.parquet_cache:
                write_parquet_cache(ppaths, nvtx, _enum_t, kr)
        cudf_ids = domain_ids_for_name(conn, ns.domain)
        # NOTE(review): kvikio_ids may legitimately be empty (report without
        # kvikio activity); compute_io_summary is assumed to tolerate that --
        # confirm against its implementation.
        kvikio_ids = domain_ids_for_name(conn, ns.kvikio_domain)
        if not cudf_ids:
            print(f"No domain {ns.domain!r}.", file=sys.stderr)
            return 1
        qi = load_query_iteration_windows(nvtx, cudf_ids, et_ids)
        windows, _ = filter_windows_for_cli(qi, ns.query, ns.iteration)
        if (
            (ns.query is not None or ns.iteration is not None)
            and windows is not None
            and windows.is_empty()
        ):
            print(
                "No NVTX ranges matched --query/--iteration.",
                file=sys.stderr,
            )
            return 1
        summary = compute_io_summary(
            nvtx,
            kr,
            kvikio_ids,
            cudf_ids,
            et_ids,
            windows,
            nvtx_schema,
        )
    finally:
        conn.close()
    console = Console()
    if ns.format == "table":
        render_io_table(console, summary)
    elif ns.format == "json":
        print(json.dumps(summary, indent=2))
    else:
        import io

        buf = io.StringIO()
        pl.DataFrame([summary]).write_csv(buf)
        # write_csv already terminates its output with a newline; end=""
        # avoids printing an extra trailing blank line (bug in the original).
        print(buf.getvalue(), end="")
    return 0
def main(argv: list[str] | None = None) -> int:
argv = argv if argv is not None else sys.argv[1:]
if not argv or argv[0] in ("-h", "--help"):
parser = argparse.ArgumentParser(
description=(
"cudf-polars Nsight Systems report analyzer (cpr). "
"Run 'summary --help' for Host vs Wall vs Device time."
),
)
sub = parser.add_subparsers(dest="command", required=True)
sp = sub.add_parser(
"summary",
help="Operator summary by Query/Iteration (host, wall, device time)",
)
sp.add_argument("remainder", nargs=argparse.REMAINDER, default=[], help=argparse.SUPPRESS)
io_p = sub.add_parser("io", help="I/O summary (Scan + kvikio reads)")
io_p.add_argument("remainder", nargs=argparse.REMAINDER, default=[], help=argparse.SUPPRESS)
sp.set_defaults(func=lambda a: cmd_summary(a))
io_p.set_defaults(func=lambda a: cmd_io(a))
parser.parse_args(argv if argv else ["-h"])
return 0
cmd = argv[0]
rest = argv[1:]
if cmd == "summary":
ns = argparse.Namespace(_remainder=rest)
return cmd_summary(ns)
if cmd == "io":
ns = argparse.Namespace(_remainder=rest)
return cmd_io(ns)
print(f"Unknown command: {cmd}", file=sys.stderr)
return 2
# Script entry point: propagate main()'s int return as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
# (GitHub gist page footer removed -- not part of this script.)