Created
April 8, 2026 15:54
-
-
Save TomAugspurger/355766e1ca9e0a611d3aadeb556b18a3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "polars>=1.0.0", | |
| # "numpy>=1.24", | |
| # "rich>=13", | |
| # "pyarrow>=14", | |
| # ] | |
| # /// | |
| # | |
| # SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """ | |
| cudf-polars-report analyzer. | |
| This tool analyzes Nsight Systems reports (.nsys-rep) of cudf-polars' benchmark runs. | |
| uv run cpr.py summary report.nsys-rep | |
| uv run cpr.py io report.nsys-rep --query 0 --iteration 1 | |
| Requires ``nsys`` on PATH to create the SQLite export unless one already exists. | |
| See ``cpr.py summary --help`` for how Host, Wall, and Device times differ. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import shutil | |
| import numpy as np | |
| import sqlite3 | |
| import struct | |
| import subprocess | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import polars as pl | |
| from rich.console import Console | |
| from rich.table import Table | |
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Case-insensitive; full-line match for Query N - Iteration M (polars/rust regex)
QUERY_ITER_REGEX = r"(?i)^Query\s+(\d+)\s+-\s+Iteration\s+(\d+)\s*$"
# NVTX labels excluded from operator summaries unless --all-types is requested.
DEFAULT_EXCLUDE_LABELS: frozenset[str] = frozenset({"ConvertIR", "ExecuteIR"})
# nvtxPayloadType_t (nvToolsExt.h)
NVTX_PAYLOAD_UNSIGNED_INT64 = 1
NVTX_PAYLOAD_INT64 = 2
# Substring identifying kvikio remote-read NVTX ranges (their payload carries bytes read).
KVIKIO_READ_NAME_SUBSTR = "RemoteHandle::read"
# Logical table names used for the per-segment parquet cache files.
CACHE_TABLES = ("nvtx_events", "enum_nsys_event_type", "kernels_runtime")
| # ----------------------------------------------------------------------------- | |
| # nsys export | |
| # ----------------------------------------------------------------------------- | |
def default_sqlite_path(nsys_rep: Path) -> Path:
    """Return the default SQLite export path for *nsys_rep* (same stem, ``.sqlite``)."""
    sqlite_path = nsys_rep.with_suffix(".sqlite")
    return sqlite_path
def run_nsys_export(nsys_rep: Path, sqlite_out: Path, *, force: bool) -> None:
    """Export *nsys_rep* to SQLite at *sqlite_out* via ``nsys export``.

    An already-existing output file is reused unless *force* is set.
    Raises ``subprocess.CalledProcessError`` when the export fails.
    """
    if sqlite_out.exists() and not force:
        return
    sqlite_out.parent.mkdir(parents=True, exist_ok=True)
    command = ["nsys", "export", "--type", "sqlite", "--force-overwrite=true"]
    command += ["-q", "true", "-o", str(sqlite_out), str(nsys_rep)]
    subprocess.run(command, check=True)
def ensure_nsys_available() -> None:
    """Exit with status 127 when the ``nsys`` CLI cannot be found on PATH."""
    if shutil.which("nsys") is not None:
        return
    print(
        "error: 'nsys' not found on PATH. Install Nsight Systems CLI.",
        file=sys.stderr,
    )
    raise SystemExit(127)
| # ----------------------------------------------------------------------------- | |
| # SQLite helpers | |
| # ----------------------------------------------------------------------------- | |
def sqlite_table_columns(conn: sqlite3.Connection, table: str) -> frozenset[str]:
    """Column names of *table*; an empty frozenset when the table does not exist."""
    rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
    return frozenset(str(row[1]) for row in rows)
def query_polars(conn: sqlite3.Connection, sql: str, params: tuple[Any, ...] = ()) -> pl.DataFrame:
    """Execute *sql* with *params* and return the result set as a Polars DataFrame.

    Statements that produce no result set yield an empty frame; a result set
    with zero rows yields a frame with the declared columns and no rows.
    """
    cur = conn.execute(sql, params)
    if cur.description is None:
        return pl.DataFrame()
    names = [desc[0] for desc in cur.description]
    data = cur.fetchall()
    if not data:
        return pl.DataFrame({name: [] for name in names})
    columns = {name: [row[i] for row in data] for i, name in enumerate(names)}
    return pl.DataFrame(columns)
| # ----------------------------------------------------------------------------- | |
| # NVTX payload decoding (UINT64 byte counts from kvikio) | |
| # ----------------------------------------------------------------------------- | |
def decode_nvtx_uint64(
    raw: object,
    payload_type: int | None,
) -> int | None:
    """Decode an NVTX payload as an unsigned 64-bit value (a byte count here).

    Accepts plain integers (masked to 64 bits) or little-endian binary blobs of
    at least 4 bytes. Returns ``None`` for missing payloads, unsupported
    payload types, negative signed-int64 payloads, and undersized blobs.
    """
    if raw is None:
        return None
    accepted_types = (0, NVTX_PAYLOAD_UNSIGNED_INT64, NVTX_PAYLOAD_INT64)
    if payload_type is not None and payload_type not in accepted_types:
        return None
    if isinstance(raw, int):
        if payload_type == NVTX_PAYLOAD_INT64 and raw < 0:
            return None
        return int(raw) & 0xFFFFFFFFFFFFFFFF
    if isinstance(raw, (bytes, bytearray, memoryview)):
        blob = bytes(raw)
        if len(blob) >= 8:
            return struct.unpack_from("<Q", blob, 0)[0]
        if len(blob) >= 4:
            return int(struct.unpack_from("<I", blob, 0)[0])
    return None
| # ----------------------------------------------------------------------------- | |
| # Interval helpers | |
| # ----------------------------------------------------------------------------- | |
| def _interval_union_ns_sorted_arrays(s: np.ndarray, e: np.ndarray) -> int: | |
| """Union length of ``[s[i], e[i])`` with rows sorted by ``s`` ascending.""" | |
| if s.size == 0: | |
| return 0 | |
| total = 0 | |
| cs, ce = int(s[0]), int(e[0]) | |
| for i in range(1, s.size): | |
| si, ei = int(s[i]), int(e[i]) | |
| if si > ce: | |
| total += ce - cs | |
| cs, ce = si, ei | |
| else: | |
| ce = max(ce, ei) | |
| total += ce - cs | |
| return total | |
def merge_intervals_union_ns(iv: pl.DataFrame) -> int:
    """
    Total nanoseconds in the union of ``[start, end)`` intervals.

    Expects columns ``start`` and ``end`` (cast to ``Int64``). Degenerate
    intervals (``end <= start``) are dropped, the rest are sorted by start and
    merged via NumPy arrays taken directly from Polars (no ``.to_list()``).
    """
    cleaned = (
        iv.select(pl.col("start").cast(pl.Int64), pl.col("end").cast(pl.Int64))
        .filter(pl.col("end") > pl.col("start"))
        .sort("start")
    )
    if cleaned.is_empty():
        return 0
    starts = cleaned.get_column("start").to_numpy()
    ends = cleaned.get_column("end").to_numpy()
    return _interval_union_ns_sorted_arrays(starts, ends)
def wall_union_ns_per_operator(df: pl.DataFrame) -> pl.DataFrame:
    """
    Per ``label_clean``, nanoseconds covered by the union of ``[start, end)``
    intervals on the global timeline (all instances of that operator merged).
    """
    if df.is_empty():
        return pl.DataFrame(schema={"operator": pl.Utf8, "wall_union_ns": pl.Int64})

    def _one_group(g: pl.DataFrame) -> pl.DataFrame:
        # One row per operator: its label plus the merged interval length.
        name = str(g.get_column("label_clean")[0])
        union_ns = merge_intervals_union_ns(g.select("start", "end"))
        return pl.DataFrame({"operator": [name], "wall_union_ns": [union_ns]})

    return (
        df.group_by("label_clean", maintain_order=False)
        .map_groups(_one_group)
        .select(pl.col("operator"), pl.col("wall_union_ns"))
    )
| # ----------------------------------------------------------------------------- | |
| # Extracted data loading (sqlite -> polars), optional parquet cache | |
| # ----------------------------------------------------------------------------- | |
@dataclass(frozen=True)
class NsysSchema:
    """Which NVTX payload columns this particular nsys SQLite export carries."""

    nvtx_uint64_col: str | None  # "uint64Value" or None
    nvtx_payload_col: str | None  # "payload"
    nvtx_payload_type_col: str | None  # "payloadType" or None
def detect_nvtx_payload_schema(cols: frozenset[str]) -> NsysSchema:
    """Map the NVTX_EVENTS column set to the payload columns we can read.

    A ``uint64Value`` column takes precedence; otherwise fall back to the
    ``payload`` / ``payloadType`` pair when present.
    """
    if "uint64Value" in cols:
        return NsysSchema(
            nvtx_uint64_col="uint64Value",
            nvtx_payload_col=None,
            nvtx_payload_type_col=None,
        )
    payload = "payload" if "payload" in cols else None
    payload_type = "payloadType" if "payloadType" in cols else None
    return NsysSchema(None, payload, payload_type)
def nvtx_range_event_type_ids(conn: sqlite3.Connection) -> list[int]:
    """IDs of the NVTX push/pop and start/end range event types, if present."""
    df = query_polars(
        conn,
        """
        SELECT id FROM ENUM_NSYS_EVENT_TYPE
        WHERE name IN (
            'NvtxPushPopRange',
            'NvtxStartEndRange',
            'NvtxtPushPopRange',
            'NvtxtStartEndRange'
        )
        """,
    )
    return [] if df.is_empty() else df["id"].cast(pl.Int64).to_list()
def domain_create_event_type_id(conn: sqlite3.Connection) -> int | None:
    """Event-type id of 'NvtxDomainCreate', or ``None`` when the enum lacks it."""
    df = query_polars(
        conn,
        "SELECT id FROM ENUM_NSYS_EVENT_TYPE WHERE name = 'NvtxDomainCreate' LIMIT 1",
    )
    return None if df.is_empty() else int(df["id"][0])
def domain_ids_for_name(conn: sqlite3.Connection, domain_name: str) -> list[int]:
    """Distinct NVTX domain ids whose domain-create event carries *domain_name*."""
    create_type_id = domain_create_event_type_id(conn)
    if create_type_id is None:
        return []
    df = query_polars(
        conn,
        """
        SELECT DISTINCT n.domainId
        FROM NVTX_EVENTS n
        LEFT JOIN StringIds t ON n.textId = t.id
        WHERE n.eventType = ?
          AND COALESCE(t.value, n.text) = ?
        """,
        (create_type_id, domain_name),
    )
    if df.is_empty():
        return []
    return df["domainId"].cast(pl.Int64).unique().to_list()
def build_nvtx_extract_sql(schema: NsysSchema) -> str:
    """SQL template (with a ``{et_placeholders}`` slot) selecting closed NVTX
    ranges plus whichever payload columns *schema* says this export has."""
    null_ptype = "CAST(NULL AS INTEGER) AS payload_type"
    if schema.nvtx_uint64_col:
        payload_sel = "n.uint64Value AS payload_raw"
        ptype_sel = null_ptype
    elif schema.nvtx_payload_col:
        payload_sel = f"n.{schema.nvtx_payload_col} AS payload_raw"
        if schema.nvtx_payload_type_col:
            ptype_sel = f"n.{schema.nvtx_payload_type_col} AS payload_type"
        else:
            ptype_sel = null_ptype
    else:
        # No payload columns at all: select typed NULLs to keep the shape stable.
        payload_sel = "CAST(NULL AS BLOB) AS payload_raw"
        ptype_sel = null_ptype
    return f"""
    SELECT n.start, n.end, n.globalTid, n.domainId, n.eventType,
           COALESCE(t.value, n.text) AS label,
           {payload_sel}, {ptype_sel}
    FROM NVTX_EVENTS n
    LEFT JOIN StringIds t ON n.textId = t.id
    WHERE n.eventType IN ({{et_placeholders}})
      AND n.end IS NOT NULL
    """
def load_enum_table(conn: sqlite3.Connection) -> pl.DataFrame:
    """The full ENUM_NSYS_EVENT_TYPE table as (id, name)."""
    sql = "SELECT id, name FROM ENUM_NSYS_EVENT_TYPE"
    return query_polars(conn, sql)
def load_nvtx_events_extract(
    conn: sqlite3.Connection,
    event_type_ids: list[int],
) -> pl.DataFrame:
    """Closed NVTX range rows (start/end/tid/domain/type/label/payload) for the
    given event-type ids; a typed empty frame when no ids are supplied."""
    if not event_type_ids:
        return pl.DataFrame(
            schema={
                "start": pl.Int64,
                "end": pl.Int64,
                "globalTid": pl.Int64,
                "domainId": pl.Int64,
                "eventType": pl.Int64,
                "label": pl.Utf8,
                "payload_raw": pl.Binary,
                "payload_type": pl.Int64,
            }
        )
    schema = detect_nvtx_payload_schema(sqlite_table_columns(conn, "NVTX_EVENTS"))
    placeholders = ",".join("?" * len(event_type_ids))
    sql = build_nvtx_extract_sql(schema).replace("{et_placeholders}", placeholders)
    df = query_polars(conn, sql, tuple(event_type_ids))
    # Normalize integer dtypes (SQLite may surface them differently).
    casts = [
        pl.col(name).cast(pl.Int64)
        for name in ("start", "end", "globalTid", "domainId", "eventType")
        if name in df.columns
    ]
    if "payload_type" in df.columns:
        casts.append(pl.col("payload_type").cast(pl.Int64))
    if casts:
        df = df.with_columns(casts)
    return df
def launch_name_ids(conn: sqlite3.Connection) -> list[int]:
    """StringIds ids whose value looks like a CUDA kernel-launch API name."""
    df = query_polars(
        conn,
        """
        SELECT id FROM StringIds
        WHERE value LIKE 'cudaLaunchKernel%' OR value LIKE 'cudaLaunch%'
        """,
    )
    return [] if df.is_empty() else df["id"].cast(pl.Int64).to_list()
def _empty_kernels_runtime() -> pl.DataFrame:
    """Empty frame with the ``load_kernels_runtime`` output schema."""
    return pl.DataFrame(
        schema={
            "k_start": pl.Int64,
            "k_end": pl.Int64,
            "launch_start": pl.Int64,
            "globalTid": pl.Int64,
        }
    )


def load_kernels_runtime(conn: sqlite3.Connection) -> pl.DataFrame:
    """Kernel intervals joined to runtime launch rows (for globalTid and launch time).

    Returns an empty frame (same schema) when no launch-API names are found,
    the kernel activity table is missing, or the correlation join fails —
    exports without GPU tracing simply yield no kernel attribution.
    """
    lids = launch_name_ids(conn)
    if not lids:
        return _empty_kernels_runtime()
    if not sqlite_table_columns(conn, "CUPTI_ACTIVITY_KIND_KERNEL"):
        return _empty_kernels_runtime()
    ph = ",".join("?" * len(lids))
    sql = f"""
    SELECT k.start AS k_start, k.end AS k_end, r.start AS launch_start, r.globalTid
    FROM CUPTI_ACTIVITY_KIND_KERNEL k
    INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME r
        ON k.correlationId = r.correlationId
    WHERE r.nameId IN ({ph})
      AND r.globalTid IS NOT NULL
    """
    try:
        df = query_polars(conn, sql, tuple(lids))
    except sqlite3.Error:
        # Narrowed from ``except Exception``: only database-level failures
        # (e.g. a missing CUPTI_ACTIVITY_KIND_RUNTIME table) are best-effort.
        return _empty_kernels_runtime()
    if df.is_empty():
        return df
    return df.with_columns(
        pl.col(name).cast(pl.Int64)
        for name in ("k_start", "k_end", "launch_start", "globalTid")
    )
def cache_key_segment(query: int | None, iteration: int | None) -> str:
    """Cache-file name segment ``<query>-<iteration>``, with ``all`` for wildcards."""
    parts = ["all" if value is None else str(value) for value in (query, iteration)]
    return "-".join(parts)
def parquet_cache_paths(
    nsys_rep: Path,
    segment: str,
) -> dict[str, Path]:
    """Per-table parquet cache paths under ``<stem>.parquet/<segment>-<table>.parquet``."""
    cache_dir = nsys_rep.parent / f"{nsys_rep.stem}.parquet"
    return {table: cache_dir / f"{segment}-{table}.parquet" for table in CACHE_TABLES}
def sqlite_mtime(path: Path) -> float:
    """Modification time of *path* in seconds since the epoch."""
    stat_result = path.stat()
    return stat_result.st_mtime
def cache_is_fresh(sqlite_path: Path, paths: dict[str, Path]) -> bool:
    """True when every cache file exists and is no older than the SQLite export."""
    if not sqlite_path.is_file():
        return False
    source_mtime = sqlite_mtime(sqlite_path)
    return all(
        p.is_file() and sqlite_mtime(p) >= source_mtime for p in paths.values()
    )
def load_extracted_tables(
    conn: sqlite3.Connection,
    event_type_ids: list[int],
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Load (nvtx events, event-type enum, kernels/runtime join) from the export."""
    return (
        load_nvtx_events_extract(conn, event_type_ids),
        load_enum_table(conn),
        load_kernels_runtime(conn),
    )
def maybe_read_parquet_cache(
    paths: dict[str, Path],
    sqlite_path: Path,
    *,
    use_cache: bool,
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame] | None:
    """Read the three cached tables, or ``None`` when caching is off or stale."""
    if not (use_cache and cache_is_fresh(sqlite_path, paths)):
        return None
    nvtx = pl.read_parquet(paths["nvtx_events"])
    enum_t = pl.read_parquet(paths["enum_nsys_event_type"])
    kr = pl.read_parquet(paths["kernels_runtime"])
    return nvtx, enum_t, kr
def write_parquet_cache(
    paths: dict[str, Path],
    nvtx: pl.DataFrame,
    enum_t: pl.DataFrame,
    kr: pl.DataFrame,
) -> None:
    """Write the three extracted tables to their parquet cache paths."""
    paths["nvtx_events"].parent.mkdir(parents=True, exist_ok=True)
    for key, frame in (
        ("nvtx_events", nvtx),
        ("enum_nsys_event_type", enum_t),
        ("kernels_runtime", kr),
    ):
        frame.write_parquet(paths[key])
| # ----------------------------------------------------------------------------- | |
| # Query / iteration windows | |
| # ----------------------------------------------------------------------------- | |
def load_query_iteration_windows(
    nvtx: pl.DataFrame,
    domain_ids: list[int],
    event_type_ids: list[int],
) -> pl.DataFrame:
    """Rows for ``Query N - Iteration M`` NVTX ranges: start, end, query_num, iteration_num."""
    empty = pl.DataFrame(
        schema={
            "start": pl.Int64,
            "end": pl.Int64,
            "query_num": pl.Int64,
            "iteration_num": pl.Int64,
        }
    )
    if nvtx.is_empty() or not domain_ids or not event_type_ids:
        return empty
    label = pl.col("label").cast(pl.Utf8).str.strip_chars()
    in_scope = (
        pl.col("domainId").is_in(domain_ids)
        & pl.col("eventType").is_in(event_type_ids)
        & pl.col("label").is_not_null()
    )
    return (
        nvtx.filter(in_scope)
        .with_columns(
            # Non-matching labels yield nulls (strict=False) and are dropped below.
            label.str.extract(QUERY_ITER_REGEX, 1)
            .cast(pl.Int64, strict=False)
            .alias("query_num"),
            label.str.extract(QUERY_ITER_REGEX, 2)
            .cast(pl.Int64, strict=False)
            .alias("iteration_num"),
        )
        .filter(pl.col("query_num").is_not_null() & pl.col("iteration_num").is_not_null())
        .select(
            pl.col("start").cast(pl.Int64),
            pl.col("end").cast(pl.Int64),
            pl.col("query_num"),
            pl.col("iteration_num"),
        )
    )
def filter_windows_for_cli(
    qi: pl.DataFrame,
    query_id: int | None,
    iteration: int | None,
) -> tuple[pl.DataFrame | None, pl.DataFrame]:
    """
    Returns (windows, wall_slices).

    ``wall_slices``: ``start``, ``end`` for all ``Query N - Iteration M`` rows in ``qi``.
    If ``query_id`` / ``iteration`` are set: ``windows`` is a filtered copy of those
    columns (may be empty). Otherwise ``windows`` is ``None`` (no filtering).
    """
    if query_id is None and iteration is None:
        return None, qi.select(pl.col("start"), pl.col("end"))
    # Reuse the shared predicate builder instead of duplicating it here
    # (keeps the filtering semantics in one place: filter_qi_dataframe).
    filtered = filter_qi_dataframe(qi, query_id, iteration).select(
        pl.col("start"), pl.col("end")
    )
    return filtered, filtered
def filter_qi_dataframe(
    qi: pl.DataFrame,
    query_id: int | None,
    iteration: int | None,
) -> pl.DataFrame:
    """Filter ``Query N - Iteration M`` rows by optional ``query_id`` / ``iteration``."""
    predicates = []
    if query_id is not None:
        predicates.append(pl.col("query_num") == query_id)
    if iteration is not None:
        predicates.append(pl.col("iteration_num") == iteration)
    if not predicates:
        return qi
    combined = predicates[0]
    for extra in predicates[1:]:
        combined = combined & extra
    return qi.filter(combined)
def assign_query_iteration_to_events(df: pl.DataFrame, qi: pl.DataFrame) -> pl.DataFrame:
    """
    Add ``query`` and ``iteration`` columns by matching each event to the
    ``Query N - Iteration M`` window with the largest time overlap (tie: lower query, iteration).

    Events overlapping no window get null ``query`` / ``iteration``.
    """
    # No windows at all: annotate every event with nulls.
    if qi.is_empty():
        return df.with_columns(
            pl.lit(None).cast(pl.Int64).alias("query"),
            pl.lit(None).cast(pl.Int64).alias("iteration"),
        )
    # Rename window bounds so they can sit next to event columns after the cross join.
    w = qi.select(
        pl.col("start").alias("qs"),
        pl.col("end").alias("qe"),
        pl.col("query_num").alias("query"),
        pl.col("iteration_num").alias("iteration"),
    )
    # All (event, window) pairs that overlap in time; ``__rid`` is a stable
    # per-event row id and ``_ov`` the overlap length in nanoseconds.
    j = (
        df.with_row_index("__rid")
        .join(w, how="cross")
        .filter((pl.col("start") < pl.col("qe")) & (pl.col("end") > pl.col("qs")))
        .with_columns(
            (
                pl.min_horizontal(pl.col("end"), pl.col("qe"))
                - pl.max_horizontal(pl.col("start"), pl.col("qs"))
            ).alias("_ov")
        )
    )
    # No event overlapped any window: same null annotation as the empty case.
    if j.is_empty():
        return df.with_columns(
            pl.lit(None).cast(pl.Int64).alias("query"),
            pl.lit(None).cast(pl.Int64).alias("iteration"),
        )
    # Per event: sort candidates by overlap descending, then query/iteration
    # ascending, and keep the first — i.e. largest overlap, ties to lower ids.
    picked = (
        j.sort(["__rid", "_ov", "query", "iteration"], descending=[False, True, False, False])
        .unique(subset=["__rid"], keep="first")
        .select("__rid", "query", "iteration")
    )
    # Left join back so unmatched events keep null query/iteration.
    return df.with_row_index("__rid").join(picked, on="__rid", how="left").drop("__rid")
def nvtx_overlaps_windows(
    start: int,
    end: int,
    windows: pl.DataFrame | None,
) -> bool:
    """Whether ``[start, end)`` overlaps any window row (``None`` means no filter)."""
    if windows is None:
        return True
    if windows.is_empty():
        return False
    overlapping = windows.filter((pl.col("start") < end) & (pl.col("end") > start))
    return overlapping.height > 0
def filter_intervals_overlap_windows(
    df: pl.DataFrame,
    windows: pl.DataFrame | None,
) -> pl.DataFrame:
    """
    Keep rows whose ``[start, end)`` overlaps at least one row in ``windows``
    (same semantics as :func:`nvtx_overlaps_windows`, vectorized).
    """
    if windows is None:
        return df
    if df.is_empty() or windows.is_empty():
        return df.head(0)
    bounds = windows.select(
        pl.col("start").cast(pl.Int64).alias("_ws"),
        pl.col("end").cast(pl.Int64).alias("_we"),
    )
    indexed = df.with_row_index("_eid")
    overlaps = (pl.col("start") < pl.col("_we")) & (pl.col("end") > pl.col("_ws"))
    hit_ids = indexed.join(bounds, how="cross").filter(overlaps).select("_eid").unique()
    return indexed.join(hit_ids, on="_eid", how="inner").drop("_eid")
| # ----------------------------------------------------------------------------- | |
| # Analytics: operator summary | |
| # ----------------------------------------------------------------------------- | |
def _empty_device_by_operator() -> pl.DataFrame:
    """Empty (operator, device_time_ns) frame used by kernel attribution."""
    schema = {"operator": pl.Utf8, "device_time_ns": pl.Int64}
    return pl.DataFrame(schema=schema)
def attribute_kernel_ms_innermost(
    nvtx_ranges: pl.DataFrame,
    kernels: pl.DataFrame,
    *,
    label_col: str = "label",
) -> pl.DataFrame:
    """
    For each kernel row, attribute full kernel duration (k_end - k_start) to the
    NVTX range label on the same globalTid that is innermost at launch: among
    ranges with start <= launch_start < end, pick the range with maximum start.

    ``nvtx_ranges`` must include columns ``start``, ``end``, ``globalTid``, and
    ``label_col`` (default ``label``). ``kernels`` must include ``k_start``,
    ``k_end``, ``launch_start``, ``globalTid``.

    Returns a DataFrame with columns ``operator`` and ``device_time_ns`` (one row
    per distinct NVTX label that received kernel time).
    """
    if nvtx_ranges.is_empty() or kernels.is_empty():
        return _empty_device_by_operator()
    # Keep non-degenerate ranges; rename bounds so they can sit next to kernel
    # columns after the join.
    nvtx = (
        nvtx_ranges.filter(pl.col("end") > pl.col("start"))
        .select(
            pl.col("start").alias("n_start"),
            pl.col("end").alias("n_end"),
            pl.col("globalTid"),
            pl.col(label_col).alias("nvtx_label"),
        )
    )
    if nvtx.is_empty():
        return _empty_device_by_operator()
    # ``__kid`` = stable per-kernel id. Candidates are NVTX ranges on the same
    # thread that were open at the instant the kernel was launched.
    k = kernels.filter(pl.col("k_end") > pl.col("k_start")).with_row_index("__kid")
    cand = k.join(nvtx, on="globalTid", how="inner").filter(
        (pl.col("n_start") <= pl.col("launch_start"))
        & (pl.col("launch_start") < pl.col("n_end"))
    )
    if cand.is_empty():
        return _empty_device_by_operator()
    # Per kernel, keep the NVTX candidate with largest ``start`` (innermost).
    picked = (
        cand.sort(["__kid", "n_start"], descending=[False, True])
        .unique(subset=["__kid"], keep="first")
        .with_columns((pl.col("k_end") - pl.col("k_start")).alias("dur_ns"))
    )
    # Sum the full kernel durations per winning label.
    return (
        picked.group_by("nvtx_label")
        .agg(pl.col("dur_ns").sum().alias("device_time_ns"))
        .rename({"nvtx_label": "operator"})
    )
# Column order shared by per-partition summary rows and the final output table.
SUMMARY_ROW_COLS = [
    "operator",
    "count",
    "host_time_ns",
    "host_pct",
    "wall_time_ns",
    "wall_pct",
    "device_time_ns",
    "device_pct",
]
def _operator_summary_single_partition(
    df: pl.DataFrame,
    kernels: pl.DataFrame,
) -> pl.DataFrame:
    """
    Operator summary for one partition of IR NVTX rows (already scoped). No query/iteration columns.

    One row per operator with host/wall/device times and percentages, plus a
    trailing ``Total`` row.
    """
    # Empty partition: just a zeroed Total row.
    if df.is_empty():
        return pl.DataFrame(
            {
                "operator": ["Total"],
                "count": [0],
                "host_time_ns": [0],
                "host_pct": [0.0],
                "wall_time_ns": [0],
                "wall_pct": [0.0],
                "device_time_ns": [0],
                "device_pct": [0.0],
            }
        )
    # Wall denominator: span from first range start to last range end.
    wall_total_ns = int(df["end"].max() - df["start"].min())
    wall_by_op = wall_union_ns_per_operator(df)
    # Host time: plain per-operator sum of range durations (overlaps double-count).
    host_agg = (
        df.group_by("label_clean")
        .agg(
            (pl.col("end") - pl.col("start")).sum().alias("host_time_ns"),
            pl.len().alias("count"),
        )
        .rename({"label_clean": "operator"})
    )
    # Device time: kernel durations attributed to the innermost enclosing range.
    nvtx_for_device = df.select(
        pl.col("start"),
        pl.col("end"),
        pl.col("globalTid"),
        pl.col("label_clean").alias("label"),
    )
    dev_pl = attribute_kernel_ms_innermost(nvtx_for_device, kernels)
    # Union of operators seen on the host side or the device side.
    ops = pl.concat([host_agg.select("operator"), dev_pl.select("operator")]).unique()
    merged = (
        ops.join(host_agg, on="operator", how="left")
        .join(dev_pl, on="operator", how="left")
        .join(wall_by_op, on="operator", how="left")
        .with_columns(
            pl.col("host_time_ns").fill_null(0),
            pl.col("device_time_ns").fill_null(0),
            pl.col("count").fill_null(0),
            pl.col("wall_union_ns").fill_null(0),
        )
    )
    total_h = int(merged["host_time_ns"].sum())
    total_d = int(merged["device_time_ns"].sum())
    total_count = int(merged["count"].sum())
    # Percentages guard against zero denominators.
    merged = merged.sort("operator").with_columns(
        pl.when(pl.lit(total_h) > 0)
        .then(pl.col("host_time_ns") / pl.lit(total_h) * 100.0)
        .otherwise(0.0)
        .alias("host_pct"),
        pl.when(pl.lit(wall_total_ns) > 0)
        .then(pl.col("wall_union_ns") / pl.lit(wall_total_ns) * 100.0)
        .otherwise(0.0)
        .alias("wall_pct"),
        pl.when(pl.lit(total_d) > 0)
        .then(pl.col("device_time_ns") / pl.lit(total_d) * 100.0)
        .otherwise(0.0)
        .alias("device_pct"),
    )
    merged = (
        merged.select(
            pl.col("operator"),
            pl.col("count"),
            pl.col("host_time_ns"),
            pl.col("host_pct"),
            pl.col("wall_union_ns").alias("wall_time_ns"),
            pl.col("wall_pct"),
            pl.col("device_time_ns"),
            pl.col("device_pct"),
        )
        .with_columns(pl.col("count").cast(pl.Int64))
    )
    total_row = pl.DataFrame(
        {
            "operator": ["Total"],
            "count": [total_count],
            "host_time_ns": [total_h],
            "host_pct": [100.0 if total_h else 0.0],
            "wall_time_ns": [wall_total_ns],
            "wall_pct": [100.0 if wall_total_ns else 0.0],
            "device_time_ns": [total_d],
            "device_pct": [100.0 if total_d else 0.0],
        }
    ).select(SUMMARY_ROW_COLS)
    return pl.concat([merged, total_row])
def compute_operator_summary(
    nvtx: pl.DataFrame,
    kernels: pl.DataFrame,
    domain_ids: list[int],
    event_type_ids: list[int],
    qi: pl.DataFrame,
    query_filter: int | None,
    iteration_filter: int | None,
    *,
    all_types: bool,
    exclude: frozenset[str],
) -> pl.DataFrame:
    """Operator summary grouped by ``query`` and ``iteration`` from ``Query N - Iteration M`` NVTX."""
    # Scope to closed ranges in the requested domains with usable labels.
    df = nvtx.filter(
        pl.col("domainId").is_in(domain_ids)
        & pl.col("eventType").is_in(event_type_ids)
        & pl.col("label").is_not_null()
        & (pl.col("end") > pl.col("start"))
    ).with_columns(
        pl.col("label").cast(pl.Utf8).str.strip_chars().alias("label_clean"),
    )
    df = df.filter(pl.col("label_clean").str.len_chars() > 0)
    # Drop excluded (framework-level) labels unless --all-types was requested.
    if not all_types:
        df = df.filter(~pl.col("label_clean").is_in(list(exclude)))
    # The window-marker ranges themselves are not operators.
    df = df.filter(~pl.col("label_clean").str.contains(QUERY_ITER_REGEX))
    qi_use = filter_qi_dataframe(qi, query_filter, iteration_filter)
    summary_out_cols = ["query", "iteration", *SUMMARY_ROW_COLS]
    # No windows recorded at all: summarize everything as one partition with
    # null query/iteration.
    if qi.is_empty():
        part = _operator_summary_single_partition(df, kernels)
        return part.with_columns(
            pl.lit(None).cast(pl.Int64).alias("query"),
            pl.lit(None).cast(pl.Int64).alias("iteration"),
        ).select(summary_out_cols)
    # Tag each operator range with its best-overlap window; drop unmatched rows.
    df_qi = assign_query_iteration_to_events(df, qi_use)
    df_qi = df_qi.filter(pl.col("query").is_not_null() & pl.col("iteration").is_not_null())
    if df_qi.is_empty():
        return pl.DataFrame(
            schema={
                "query": pl.Int64,
                "iteration": pl.Int64,
                "operator": pl.Utf8,
                "count": pl.Int64,
                "host_time_ns": pl.Int64,
                "host_pct": pl.Float64,
                "wall_time_ns": pl.Int64,
                "wall_pct": pl.Float64,
                "device_time_ns": pl.Int64,
                "device_pct": pl.Float64,
            }
        )
    def _add_qi_cols(g: pl.DataFrame) -> pl.DataFrame:
        # Summarize one (query, iteration) group, then restore the group keys.
        qv = int(g["query"][0])
        iv = int(g["iteration"][0])
        sub = g.drop("query", "iteration")
        block = _operator_summary_single_partition(sub, kernels)
        return block.with_columns(
            pl.lit(qv).alias("query"),
            pl.lit(iv).alias("iteration"),
        ).select(summary_out_cols)
    out = df_qi.group_by(["query", "iteration"], maintain_order=False).map_groups(
        _add_qi_cols
    )
    # Keep ``Total`` last within each (query, iteration); sorting only by ``operator``
    # would place it alphabetically (e.g. before ``Union``).
    return (
        out.with_columns((pl.col("operator") == "Total").cast(pl.Int8).alias("_total_last"))
        .sort(["query", "iteration", "_total_last", "operator"])
        .drop("_total_last")
    )
| # ----------------------------------------------------------------------------- | |
| # Analytics: I/O (kvikio bytes + Scan) | |
| # ----------------------------------------------------------------------------- | |
def row_payload_bytes(
    row: dict[str, Any],
    schema: NsysSchema,
    nvtx_cols: frozenset[str],
) -> int | None:
    """Decoded payload byte count for one NVTX row dict, honoring *schema*."""
    if schema.nvtx_uint64_col and "payload_raw" in row:
        value = row["payload_raw"]
        if value is not None:
            return int(value) & 0xFFFFFFFFFFFFFFFF
    payload_type = row.get("payload_type")
    return decode_nvtx_uint64(
        row.get("payload_raw"),
        None if payload_type is None else int(payload_type),
    )
def add_payload_bytes_column(df: pl.DataFrame, schema: NsysSchema) -> pl.DataFrame:
    """Add ``payload_bytes`` (decoded byte count) per NVTX row; empty ``df`` unchanged shape."""
    if df.is_empty():
        return df.with_columns(pl.lit(None).cast(pl.Int64).alias("payload_bytes"))
    if schema.nvtx_uint64_col:
        # SQLite / Polars may load uint64Value as Float64; mask in UInt64 space, then Int64.
        raw = pl.col("payload_raw")
        mask = pl.lit(0xFFFFFFFFFFFFFFFF, dtype=pl.UInt64)
        masked = (raw.cast(pl.UInt64, strict=False) & mask).cast(pl.Int64)
        return df.with_columns(
            pl.when(raw.is_not_null())
            .then(masked)
            .otherwise(pl.lit(None, dtype=pl.Int64))
            .alias("payload_bytes")
        )

    def _decode(r: dict[str, Any]) -> int | None:
        return decode_nvtx_uint64(r["payload_raw"], r["payload_type"])

    return df.with_columns(
        pl.struct(["payload_raw", "payload_type"])
        .map_elements(_decode, return_dtype=pl.Int64)
        .alias("payload_bytes")
    )
def compute_io_summary(
    nvtx: pl.DataFrame,
    kernels: pl.DataFrame,
    kvikio_domain_ids: list[int],
    cudf_domain_ids: list[int],
    event_type_ids: list[int],
    windows: pl.DataFrame | None,
    nvtx_sqlite_cols: NsysSchema,
) -> dict[str, Any]:
    """I/O summary: kvikio read bytes plus Scan host/wall/device times and throughputs."""
    kset = set(kvikio_domain_ids)
    cset = set(cudf_domain_ids)
    et_list = list(event_type_ids)
    # All closed ranges of the right event types, optionally restricted to the
    # selected query/iteration windows.
    base = nvtx.filter(
        pl.col("eventType").is_in(et_list) & (pl.col("end") > pl.col("start"))
    )
    base = filter_intervals_overlap_windows(base, windows)
    label_txt = (
        pl.col("label").cast(pl.Utf8).fill_null("").str.strip_chars().alias("_label_txt")
    )
    base = base.with_columns(label_txt)
    # kvikio remote reads carry the byte count in the NVTX payload.
    kv = base.filter(
        pl.col("domainId").is_in(list(kset))
        & pl.col("_label_txt").str.contains(KVIKIO_READ_NAME_SUBSTR, literal=True)
    )
    kv = add_payload_bytes_column(kv, nvtx_sqlite_cols)
    total_bytes = (
        0
        if kv.is_empty()
        else int(kv.select(pl.col("payload_bytes").sum()).item() or 0)
    )
    # Scan operator ranges from the cudf domain provide the time denominators.
    scan_df = base.filter(
        pl.col("domainId").is_in(list(cset)) & (pl.col("_label_txt") == "Scan")
    ).select(
        pl.col("start").cast(pl.Int64),
        pl.col("end").cast(pl.Int64),
        pl.col("globalTid").cast(pl.Int64),
        pl.col("_label_txt").alias("label"),
    )
    scan_count = scan_df.height
    if scan_df.is_empty():
        sum_scan_host_ns = 0
        scan_union_wall_ns = 0
    else:
        # Host: plain sum of durations; wall: union on the global timeline.
        sum_scan_host_ns = int(
            scan_df.select((pl.col("end") - pl.col("start")).sum()).item()
        )
        scan_union_wall_ns = merge_intervals_union_ns(scan_df.select("start", "end"))
    # Device: kernel time attributed to the innermost Scan range.
    dev_by_op = attribute_kernel_ms_innermost(scan_df, kernels)
    scan_row = dev_by_op.filter(pl.col("operator") == "Scan")
    dev_scan_ns = int(scan_row["device_time_ns"][0]) if scan_row.height > 0 else 0
    # Throughput: bytes / denominator (same total_bytes for all)
    def tp(denom_ns: int) -> float | None:
        if denom_ns <= 0:
            return None
        return float(total_bytes) / float(denom_ns)
    return {
        "scan_count": scan_count,
        "total_bytes": total_bytes,
        "sum_scan_host_ns": sum_scan_host_ns,
        "scan_union_wall_ns": scan_union_wall_ns,
        "sum_scan_device_ns": dev_scan_ns,
        "throughput_bytes_per_ns_sum_host": tp(sum_scan_host_ns),
        "throughput_bytes_per_ns_union_wall": tp(scan_union_wall_ns),
        "throughput_bytes_per_ns_sum_device": tp(dev_scan_ns),
    }
| # ----------------------------------------------------------------------------- | |
| # Human-readable formatting (table output only) | |
| # ----------------------------------------------------------------------------- | |
def format_ns(ns: float) -> str:
    """Render a nanosecond duration with an auto-selected unit (s, ms, µs, ns)."""
    for scale, unit in ((1e9, "s"), (1e6, "ms"), (1e3, "µs")):
        if ns >= scale:
            return f"{ns / scale:.3f} {unit}"
    return f"{ns:.0f} ns"
def format_bytes(n: int) -> str:
    """Render a byte count with an auto-selected binary unit (GiB, MiB, KiB)."""
    for shift, unit in ((30, "GiB"), (20, "MiB"), (10, "KiB")):
        if n >= 1 << shift:
            return f"{n / (1 << shift):.3f} {unit}"
    # Below 1 KiB the raw count is printed without a unit suffix.
    return str(n)
def format_throughput_table(bytes_count: int, denom_ns: int) -> str:
    """Render ``bytes_count / denom_ns`` as GiB/s, or MiB/s below 1 GiB/s.

    Returns "n/a" when the denominator is not positive.
    """
    if denom_ns <= 0:
        return "n/a"
    rate_bps = bytes_count / denom_ns * 1e9  # ns -> s
    rate_gib = rate_bps / (1024 ** 3)
    if rate_gib >= 1.0:
        return f"{rate_gib:.4f} GiB/s"
    return f"{rate_bps / (1024 ** 2):.4f} MiB/s"
| # ----------------------------------------------------------------------------- | |
| # Render | |
| # ----------------------------------------------------------------------------- | |
| def _fmt_qi_cell(v: object) -> str: | |
| if v is None: | |
| return "—" | |
| return str(int(v)) | |
def render_summary_table(console: Console, df: pl.DataFrame) -> None:
    """Print the per-operator summary DataFrame as a rich table.

    The "Total" row is highlighted in bold; missing query/iteration values
    render as an em dash.
    """
    table = Table(show_header=True, header_style="bold")
    for name in ("Query", "Iter", "Operator"):
        table.add_column(name)
    numeric_cols = (
        "Count",
        "Host Time",
        "Host %",
        "Wall Time",
        "Wall %",
        "Device Time",
        "Device %",
    )
    for name in numeric_cols:
        table.add_column(name, justify="right")
    for rec in df.iter_rows(named=True):
        operator = rec["operator"]
        table.add_row(
            _fmt_qi_cell(rec.get("query")),
            _fmt_qi_cell(rec.get("iteration")),
            operator,
            str(int(rec["count"])),
            format_ns(float(rec["host_time_ns"])),
            f"{rec['host_pct']:.1f}%",
            format_ns(float(rec["wall_time_ns"])),
            f"{rec['wall_pct']:.1f}%",
            format_ns(float(rec["device_time_ns"])),
            f"{rec['device_pct']:.1f}%",
            style="bold" if operator == "Total" else None,
        )
    console.print(table)
def render_io_table(console: Console, d: dict[str, Any]) -> None:
    """Print the I/O summary mapping from compute_io_summary section by section."""
    console.print("[bold]Scan[/bold]")
    console.print(f" Count: {d['scan_count']}")
    console.print(f" Bytes: {d['total_bytes']} ({format_bytes(int(d['total_bytes']))})")
    console.print()
    console.print("[bold]Durations[/bold]")
    console.print(f" Sum host durations: {format_ns(float(d['sum_scan_host_ns']))}")
    console.print(
        f" Union host durations (wall time): {format_ns(float(d['scan_union_wall_ns']))}"
    )
    # Device-time row ('sum_scan_device_ns') intentionally not printed here.
    console.print()
    console.print("[bold]Throughput (total_bytes / denominator)[/bold]")
    tb = int(d["total_bytes"])
    denominators = (
        ("Host time", "sum_scan_host_ns"),
        ("Wall time", "scan_union_wall_ns"),
    )
    for label, key in denominators:
        den = int(d[key])
        cell = format_throughput_table(tb, den) if den > 0 else "n/a"
        console.print(f" {label}: {cell}")
| def df_to_json_records(df: pl.DataFrame) -> str: | |
| return json.dumps(df.to_dicts(), indent=2) | |
def print_csv(df: pl.DataFrame) -> None:
    """Write *df* to stdout as CSV.

    ``DataFrame.write_csv()`` (called without a target) returns a string that
    already ends with a newline, so a bare ``print`` would emit a trailing
    blank line; suppress print's own newline instead.
    """
    print(df.write_csv(), end="")
| # ----------------------------------------------------------------------------- | |
| # Main | |
| # ----------------------------------------------------------------------------- | |
def add_export_args(p: argparse.ArgumentParser) -> None:
    """Register the CLI options shared by every subcommand on parser *p*.

    Covers: the report path, SQLite export control, NVTX domain names,
    query/iteration window filters, output format, and parquet caching.
    """
    p.add_argument("report", type=Path, help="Path to .nsys-rep")
    p.add_argument(
        "--sqlite",
        type=Path,
        default=None,
        help="SQLite path (default: <report>.sqlite)",
    )
    # Export control: skip the export entirely, or force a re-export.
    p.add_argument(
        "--no-export",
        action="store_true",
        help="Do not run nsys export; require existing SQLite",
    )
    p.add_argument(
        "--force-export",
        action="store_true",
        help="Re-run nsys export even if SQLite exists",
    )
    # NVTX domains of interest.
    p.add_argument(
        "--domain",
        default="cudf_polars",
        help="NVTX domain for IR / Query ranges (default: cudf_polars)",
    )
    p.add_argument(
        "--kvikio-domain",
        default="libkvikio",
        help="NVTX domain for kvikio reads (default: libkvikio)",
    )
    query_help = (
        "Restrict summary to this query index N (only these Query N - Iteration M "
        "windows are used to label events; default: all queries)"
    )
    p.add_argument("--query", type=int, default=None, help=query_help)
    iteration_help = (
        "Restrict summary to this iteration M (only these Query N - Iteration M "
        "windows are used to label events; default: all iterations)"
    )
    p.add_argument("--iteration", type=int, default=None, help=iteration_help)
    p.add_argument(
        "--format",
        choices=("table", "json", "csv"),
        default="table",
        help="Output format (table uses human units; json/csv use raw ns and bytes)",
    )
    # Paired on/off flags sharing one dest; the "on" variant carries the default.
    p.add_argument(
        "--parquet-cache",
        dest="parquet_cache",
        action="store_true",
        default=True,
        help="Write parquet extract cache (default: on)",
    )
    p.add_argument(
        "--no-parquet-cache",
        dest="parquet_cache",
        action="store_false",
        help="Do not write parquet cache",
    )
    p.add_argument(
        "--use-cache",
        dest="use_cache",
        action="store_true",
        default=True,
        help="Load from parquet cache when fresh (default: on)",
    )
    p.add_argument(
        "--no-use-cache",
        dest="use_cache",
        action="store_false",
        help="Always read SQLite",
    )
# Epilog shown by `summary --help`: explains how the Host / Wall / Device
# time columns of the summary table are computed (used in cmd_summary).
SUMMARY_HELP_EPILOG = """
Time columns (summary):
Host Sum of NVTX range lengths (end - start) for each range tagged with that
operator on CPU threads. Parent and child ranges can overlap in time, so
sums can exceed wall-clock span. The Host percent column is the share of
that summed time within the Query/Iteration group.
Wall Union of those intervals on one global timeline (overlaps merged). The Wall
percent column is union length divided by the span from the earliest IR
start to the latest IR end in the same group (calendar coverage).
Device Sum of GPU kernel durations attributed to this operator: each kernel is
assigned to the innermost active NVTX range on the launching CPU thread at
cuda launch. The Device percent column is the share of summed kernel time
in the group.
"""
def cmd_summary(args: argparse.Namespace) -> int:
    """Run the ``summary`` subcommand; return a process exit code (0 = ok).

    Parses its own options from ``args._remainder``, ensures a SQLite export
    of the report exists, loads (or reuses a cached) NVTX/kernel extract, and
    renders the per-operator summary in the requested output format.
    """
    # Second-stage parser: summary-specific flags plus the shared options.
    p = argparse.ArgumentParser(
        description=(
            "Per-operator table grouped by Query / Iteration from NVTX. "
            "Shows host (CPU NVTX), wall (union) coverage, and device (GPU) time."
        ),
        epilog=SUMMARY_HELP_EPILOG,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument(
        "--all-types",
        action="store_true",
        help="Include every NVTX label in the domain (not only known IR names)",
    )
    p.add_argument(
        "--exclude",
        default=",".join(sorted(DEFAULT_EXCLUDE_LABELS)),
        help="Comma-separated labels to skip (default: ConvertIR,ExecuteIR)",
    )
    add_export_args(p)
    ns = p.parse_args(args._remainder)
    nsys_rep = ns.report.expanduser().resolve()
    if not nsys_rep.is_file():
        print(f"Report not found: {nsys_rep}", file=sys.stderr)
        return 1
    sqlite_path = (ns.sqlite or default_sqlite_path(nsys_rep)).expanduser().resolve()
    # Ensure a SQLite export exists: run nsys unless --no-export was given.
    if not ns.no_export:
        ensure_nsys_available()
        try:
            run_nsys_export(nsys_rep, sqlite_path, force=ns.force_export)
        except subprocess.CalledProcessError as e:
            print(f"nsys export failed (exit {e.returncode})", file=sys.stderr)
            return e.returncode or 1
    elif not sqlite_path.is_file():
        print(f"SQLite not found (--no-export): {sqlite_path}", file=sys.stderr)
        return 1
    segment = cache_key_segment(ns.query, ns.iteration)
    ppaths = parquet_cache_paths(nsys_rep, segment)
    # Read-only connection; all SQLite access stays inside the try block.
    conn = sqlite3.connect(f"file:{sqlite_path}?mode=ro", uri=True)
    try:
        et_ids = nvtx_range_event_type_ids(conn)
        if not et_ids:
            print("No NVTX range event types in export.", file=sys.stderr)
            return 1
        # Prefer the parquet cache when fresh; otherwise extract from SQLite
        # and (optionally) write the cache for next time.
        cached = maybe_read_parquet_cache(ppaths, sqlite_path, use_cache=ns.use_cache)
        if cached is not None:
            nvtx, enum_t, kr = cached
        else:
            nvtx, enum_t, kr = load_extracted_tables(conn, et_ids)
            if ns.parquet_cache:
                write_parquet_cache(ppaths, nvtx, enum_t, kr)
        domain_ids = domain_ids_for_name(conn, ns.domain)
        if not domain_ids:
            print(
                f"No NvtxDomainCreate for domain {ns.domain!r}.",
                file=sys.stderr,
            )
            return 1
        qi = load_query_iteration_windows(nvtx, domain_ids, et_ids)
        # NOTE(review): qi_filtered is only used to fail fast when the
        # requested query/iteration matches nothing; the summary itself is
        # given the unfiltered `qi` plus query_filter/iteration_filter below
        # — presumably compute_operator_summary applies them itself; confirm.
        qi_filtered = filter_qi_dataframe(qi, ns.query, ns.iteration)
        if (ns.query is not None or ns.iteration is not None) and qi_filtered.is_empty():
            print(
                "No NVTX ranges matched --query/--iteration.",
                file=sys.stderr,
            )
            return 1
        # Drop empty entries so "--exclude a,,b" behaves like "a,b".
        exclude = frozenset(
            x.strip() for x in ns.exclude.split(",") if x.strip()
        )
        df = compute_operator_summary(
            nvtx,
            kr,
            domain_ids,
            et_ids,
            qi,
            query_filter=ns.query,
            iteration_filter=ns.iteration,
            all_types=ns.all_types,
            exclude=exclude,
        )
    finally:
        conn.close()
    console = Console()
    if ns.format == "table":
        render_summary_table(console, df)
    elif ns.format == "json":
        print(df_to_json_records(df))
    else:
        print_csv(df)
    return 0
def cmd_io(args: argparse.Namespace) -> int:
    """Run the ``io`` subcommand; return a process exit code (0 = ok).

    Parses the shared options from ``args._remainder``, ensures a SQLite
    export exists, loads the NVTX/kernel extract, computes the Scan/kvikio
    I/O summary, and renders it in the requested output format.
    """
    # Only the shared export/report options apply to this subcommand.
    p = argparse.ArgumentParser()
    add_export_args(p)
    ns = p.parse_args(args._remainder)
    nsys_rep = ns.report.expanduser().resolve()
    if not nsys_rep.is_file():
        print(f"Report not found: {nsys_rep}", file=sys.stderr)
        return 1
    sqlite_path = (ns.sqlite or default_sqlite_path(nsys_rep)).expanduser().resolve()
    # Ensure a SQLite export exists: run nsys unless --no-export was given.
    if not ns.no_export:
        ensure_nsys_available()
        try:
            run_nsys_export(nsys_rep, sqlite_path, force=ns.force_export)
        except subprocess.CalledProcessError as e:
            print(f"nsys export failed (exit {e.returncode})", file=sys.stderr)
            return e.returncode or 1
    elif not sqlite_path.is_file():
        print(f"SQLite not found (--no-export): {sqlite_path}", file=sys.stderr)
        return 1
    segment = cache_key_segment(ns.query, ns.iteration)
    ppaths = parquet_cache_paths(nsys_rep, segment)
    # Read-only connection; all SQLite access stays inside the try block.
    conn = sqlite3.connect(f"file:{sqlite_path}?mode=ro", uri=True)
    try:
        # Detect the NVTX payload column layout from the exported schema so
        # payload byte counts can be decoded later.
        cols = sqlite_table_columns(conn, "NVTX_EVENTS")
        nvtx_schema = detect_nvtx_payload_schema(cols)
        et_ids = nvtx_range_event_type_ids(conn)
        if not et_ids:
            print("No NVTX range event types in export.", file=sys.stderr)
            return 1
        cached = maybe_read_parquet_cache(ppaths, sqlite_path, use_cache=ns.use_cache)
        if cached is not None:
            nvtx, _enum_t, kr = cached
        else:
            nvtx, _enum_t, kr = load_extracted_tables(conn, et_ids)
            if ns.parquet_cache:
                write_parquet_cache(ppaths, nvtx, _enum_t, kr)
        cudf_ids = domain_ids_for_name(conn, ns.domain)
        # A missing kvikio domain is tolerated (total bytes then come out 0);
        # only the cudf domain is required for the Scan windows.
        kvikio_ids = domain_ids_for_name(conn, ns.kvikio_domain)
        if not cudf_ids:
            print(f"No domain {ns.domain!r}.", file=sys.stderr)
            return 1
        qi = load_query_iteration_windows(nvtx, cudf_ids, et_ids)
        windows, _ = filter_windows_for_cli(qi, ns.query, ns.iteration)
        if (
            (ns.query is not None or ns.iteration is not None)
            and windows is not None
            and windows.is_empty()
        ):
            print(
                "No NVTX ranges matched --query/--iteration.",
                file=sys.stderr,
            )
            return 1
        summary = compute_io_summary(
            nvtx,
            kr,
            kvikio_ids,
            cudf_ids,
            et_ids,
            windows,
            nvtx_schema,
        )
    finally:
        conn.close()
    console = Console()
    if ns.format == "table":
        render_io_table(console, summary)
    elif ns.format == "json":
        print(json.dumps(summary, indent=2))
    else:
        # CSV path: flatten the single summary dict into a one-row frame.
        import io
        buf = io.StringIO()
        pl.DataFrame([summary]).write_csv(buf)
        print(buf.getvalue())
    return 0
| def main(argv: list[str] | None = None) -> int: | |
| argv = argv if argv is not None else sys.argv[1:] | |
| if not argv or argv[0] in ("-h", "--help"): | |
| parser = argparse.ArgumentParser( | |
| description=( | |
| "cudf-polars Nsight Systems report analyzer (cpr). " | |
| "Run 'summary --help' for Host vs Wall vs Device time." | |
| ), | |
| ) | |
| sub = parser.add_subparsers(dest="command", required=True) | |
| sp = sub.add_parser( | |
| "summary", | |
| help="Operator summary by Query/Iteration (host, wall, device time)", | |
| ) | |
| sp.add_argument("remainder", nargs=argparse.REMAINDER, default=[], help=argparse.SUPPRESS) | |
| io_p = sub.add_parser("io", help="I/O summary (Scan + kvikio reads)") | |
| io_p.add_argument("remainder", nargs=argparse.REMAINDER, default=[], help=argparse.SUPPRESS) | |
| sp.set_defaults(func=lambda a: cmd_summary(a)) | |
| io_p.set_defaults(func=lambda a: cmd_io(a)) | |
| parser.parse_args(argv if argv else ["-h"]) | |
| return 0 | |
| cmd = argv[0] | |
| rest = argv[1:] | |
| if cmd == "summary": | |
| ns = argparse.Namespace(_remainder=rest) | |
| return cmd_summary(ns) | |
| if cmd == "io": | |
| ns = argparse.Namespace(_remainder=rest) | |
| return cmd_io(ns) | |
| print(f"Unknown command: {cmd}", file=sys.stderr) | |
| return 2 | |
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment