|
"""Stress-test marimo IPC across Python versions and schema skew. |
|
|
|
Builds a wheel from the current tree, then for each Python version in the |
|
matrix spawns a kernel via |
|
|
|
uv run --isolated --no-project --python=X.Y --with=<wheel> |
|
-m marimo._ipc.launch_kernel |
|
|
|
and runs three end-to-end checks against it: |
|
|
|
1. execute - plain ExecuteCellsCommand round-trips (control + stream). |
|
2. forward_command - driver sends an ExecuteCellsCommand-shaped msgpack |
|
payload with an EXTRA field; the kernel's real decoder |
|
must ignore it and still execute the cell. |
|
  3. broadcast - cell code broadcasts a Notification on a known tag with
     an EXTRA field (CellNotification subclass on tag="cell-op"). It must
     decode cleanly on the host. The kernel's own completed-run
     notification marks the end of the run.
|
|
|
A brand-new Notification tag (closed NotificationMessage union) is recorded |
|
separately as a known forward-compat hazard, not a pass/fail. |
|
|
|
Completion is event-driven: each exec uses a unique cell_id, and a scenario |
|
is done when we've seen a cell-op for that cell_id followed by completed-run. |
|
That isolates scenarios from each other without any fixed sleeps. |
|
|
|
Usage (from repo root): |
|
|
|
uv run --extra sandbox scripts/stress_test_ipc.py |
|
uv run --extra sandbox scripts/stress_test_ipc.py --python 3.11 |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import argparse |
|
import pathlib |
|
import queue |
|
import subprocess |
|
import sys |
|
import tempfile |
|
import time |
|
from dataclasses import dataclass, field |
|
from typing import Any, Callable |
|
|
|
import msgspec |
|
|
|
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent |
|
sys.path.insert(0, str(REPO_ROOT)) |
|
|
|
from marimo._ast.app_config import _AppConfig # noqa: E402 |
|
from marimo._ast.cell import CellConfig # noqa: E402 |
|
from marimo._config.config import DEFAULT_CONFIG # noqa: E402 |
|
from marimo._config.settings import GLOBAL_SETTINGS # noqa: E402 |
|
from marimo._ipc.queue_manager import QueueManager # noqa: E402 |
|
from marimo._ipc.types import KernelArgs # noqa: E402 |
|
from marimo._messaging.serde import deserialize_kernel_message # noqa: E402 |
|
from marimo._runtime.commands import ( # noqa: E402 |
|
AppMetadata, |
|
ExecuteCellsCommand, |
|
) |
|
from marimo._types.ids import CellId_t # noqa: E402 |
|
|
|
# Interpreter versions exercised when no --python flag is given.
DEFAULT_PYTHON_VERSIONS = ["3.10", "3.11", "3.12", "3.13", "3.14"]

# Overall per-run deadline, in seconds, used by await_run().
TIMEOUT = 10.0

# Tag of the deliberately unregistered Notification that scenario_broadcast()
# sends; await_run() also looks for these bytes in raw stream messages.
UNKNOWN_TAG = "stress-foo-v1"
|
|
|
|
|
@dataclass
class Result:
    """Outcome of one scenario (or the spawn step) for one Python version."""

    # Python version string the kernel was launched with, e.g. "3.12".
    python_version: str
    # Scenario label as printed in the report ("spawn", "warmup", "execute", ...).
    scenario: str
    # False marks a failure; print_report() turns any failure into exit code 1.
    passed: bool
    # Failure message, or an informational note shown for passing results.
    detail: str = ""
|
|
|
|
|
@dataclass
class Session:
    """A spawned kernel subprocess together with the host-side IPC queues.

    The two counters are updated by ``await_run`` while it drains the stream
    queue and are read by the broadcast scenario.
    """

    # Kernel launcher process (stdin/stdout/stderr are pipes).
    proc: subprocess.Popen[bytes]
    # Host end of the control/stream queues.
    qm: QueueManager
    # True once the unknown tag's bytes were observed in any raw message.
    unknown_tag_bytes_seen: bool = False
    # Number of stream messages the host failed to decode.
    host_decode_errors: int = 0

    def _close_pipes(self) -> None:
        """Close whichever subprocess pipes are still open (best effort)."""
        for pipe in (self.proc.stdin, self.proc.stdout, self.proc.stderr):
            if pipe is None or pipe.closed:
                continue
            try:
                pipe.close()
            except OSError:
                # A broken pipe while flushing on close is irrelevant here.
                pass

    def close(self) -> None:
        """Stop the kernel process and release the pipes and queues.

        Sends terminate, escalates to kill if the process has not exited
        after 3 seconds, and always closes the pipes and queues -- even when
        waiting on the process raises.

        Raises:
            subprocess.TimeoutExpired: if the process is still alive 3
                seconds after being killed (resources are released first).
        """
        try:
            self.proc.terminate()
            try:
                self.proc.wait(timeout=3.0)
            except subprocess.TimeoutExpired:
                self.proc.kill()
                self.proc.wait(timeout=3.0)
        finally:
            try:
                self._close_pipes()
            finally:
                self.qm.close_queues()
|
|
|
|
|
def build_wheel() -> pathlib.Path:
    """Build a marimo wheel from the repo into a fresh temp directory.

    Runs ``uv build --wheel`` with the repo root as the working directory.

    Returns:
        Path to the built ``marimo-*.whl``.

    Raises:
        subprocess.CalledProcessError: if ``uv build`` exits non-zero.
        RuntimeError: if the build succeeded but left no matching wheel.
    """
    out_dir = pathlib.Path(tempfile.mkdtemp(prefix="marimo-stress-"))
    print(f"[build] uv build --wheel --out-dir {out_dir}")
    subprocess.run(
        ["uv", "build", "--wheel", "--out-dir", str(out_dir)],
        check=True,
        cwd=REPO_ROOT,
    )
    # Sort so the pick is deterministic, and fail with a readable error
    # instead of a bare StopIteration when nothing was produced.
    wheels = sorted(out_dir.glob("marimo-*.whl"))
    if not wheels:
        raise RuntimeError(
            f"uv build succeeded but produced no marimo-*.whl in {out_dir}"
        )
    wheel = wheels[0]
    print(f"[build] built {wheel.name}")
    return wheel
|
|
|
|
|
def spawn(py: str, wheel: pathlib.Path, cell_ids: list[CellId_t]) -> Session | None:
    """Launch a kernel under Python ``py`` and perform the startup handshake.

    The kernel is started with ``uv run --isolated --no-project`` using the
    given wheel, receives its JSON-encoded ``KernelArgs`` on stdin, and is
    expected to print ``KERNEL_READY`` as its first stdout line.

    Args:
        py: Python version passed to ``uv run --python``.
        wheel: Wheel passed to ``uv run --with``.
        cell_ids: Cell ids to register; each gets a default ``CellConfig``.

    Returns:
        A live ``Session``, or ``None`` when uv reports that no interpreter
        for ``py`` is available.

    Raises:
        RuntimeError: if the handshake fails for any other reason. The
            message carries the exit status and the tail of stderr.
    """
    qm, connection_info = QueueManager.create()
    try:
        args = KernelArgs(
            connection_info=connection_info,
            profile_path=None,
            configs={cid: CellConfig() for cid in cell_ids},
            user_config=DEFAULT_CONFIG,
            log_level=GLOBAL_SETTINGS.LOG_LEVEL,
            app_metadata=AppMetadata(
                query_params={}, cli_args={}, app_config=_AppConfig()
            ),
        )
        payload = args.encode_json()
        proc = subprocess.Popen(
            [
                "uv", "run", "--isolated", "--no-project",
                f"--python={py}", f"--with={wheel}",
                "python", "-m", "marimo._ipc.launch_kernel",
            ],
            stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )
    except BaseException:
        # Nothing owns the queues yet, so release them before propagating.
        qm.close_queues()
        raise
    assert proc.stdin and proc.stdout and proc.stderr

    # If uv exits immediately (e.g. missing interpreter) the write can hit a
    # broken pipe. Swallow it so the stderr-based diagnosis below still runs.
    try:
        proc.stdin.write(payload)
        proc.stdin.flush()
    except OSError:
        pass
    finally:
        try:
            proc.stdin.close()
        except OSError:
            pass

    # NOTE(review): stderr stays a pipe that is not read while the session
    # runs; a very chatty kernel could fill it and stall -- confirm volume.
    line = proc.stdout.readline()
    if line.decode(errors="replace").strip() == "KERNEL_READY":
        return Session(proc=proc, qm=qm)

    # Handshake failed: collect stderr, reap the child, release everything.
    try:
        if line and proc.poll() is None:
            # Wrong line but the process is still alive: stop it so the
            # stderr read below cannot block indefinitely.
            proc.kill()
        stderr = proc.stderr.read().decode(errors="replace")
        try:
            proc.wait(timeout=3.0)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
    finally:
        proc.stdout.close()
        proc.stderr.close()
        qm.close_queues()

    if "No interpreter found" in stderr or "No interpreters found" in stderr:
        return None
    raise RuntimeError(
        f"handshake failed (exit={proc.poll()}, py={py}): {stderr[-1500:]}"
    )
|
|
|
|
|
GRACE = 0.3  # seconds to keep draining after completed-run


def await_run(session: Session, cell_id: str) -> list[Any]:
    """Block until the run for `cell_id` finishes, plus a short grace period.

    Two-phase: look for a cell-op matching `cell_id`, then the next
    completed-run. Unique cell_ids per scenario mean stale completed-run
    ops from prior scenarios can't satisfy phase 2.

    The grace period catches late messages: in-cell `broadcast_notification`
    calls race against the kernel's natural completed-run, so a broadcast's
    bytes may arrive AFTER completed-run.

    Side effect: populates session.unknown_tag_bytes_seen and
    session.host_decode_errors, which the unknown-tag check reports on.

    Args:
        session: Live session whose stream queue is drained.
        cell_id: Cell id whose run is awaited.

    Returns:
        Every successfully decoded message, in arrival order. Messages that
        fail to decode are counted on the session and skipped.

    Raises:
        TimeoutError: if no cell-op for `cell_id` arrives within TIMEOUT
            seconds, or if the run starts but no completed-run follows
            before the deadline.
    """
    # Monotonic clock: wall-clock adjustments must not stretch or shrink
    # the deadline or the grace window.
    deadline = time.monotonic() + TIMEOUT
    collected: list[Any] = []
    saw_our_cell_op = False
    grace_until: float | None = None
    while True:
        now = time.monotonic()
        if now > deadline:
            break
        if grace_until is not None and now > grace_until:
            break
        try:
            # Short poll so both deadlines are re-checked frequently.
            raw = session.qm.stream_queue.get(timeout=0.05)
        except queue.Empty:
            continue
        assert isinstance(raw, bytes)
        # Checked on the raw bytes because the decode below may reject it.
        if UNKNOWN_TAG.encode() in raw:
            session.unknown_tag_bytes_seen = True
        try:
            note = deserialize_kernel_message(raw)
        except msgspec.DecodeError:
            session.host_decode_errors += 1
            continue
        collected.append(note)
        name = getattr(note, "name", None)
        if (
            not saw_our_cell_op
            and name == "cell-op"
            and getattr(note, "cell_id", None) == cell_id
        ):
            saw_our_cell_op = True
        elif (
            saw_our_cell_op and name == "completed-run" and grace_until is None
        ):
            grace_until = time.monotonic() + GRACE
    if not saw_our_cell_op:
        raise TimeoutError(
            f"run for {cell_id!r} did not start within {TIMEOUT}s; "
            f"names: {[getattr(n, 'name', '?') for n in collected]}"
        )
    if grace_until is None:
        # The run started but never reported completion. Returning here
        # would let an unfinished run pass as finished.
        raise TimeoutError(
            f"run for {cell_id!r} did not complete within {TIMEOUT}s; "
            f"names: {[getattr(n, 'name', '?') for n in collected]}"
        )
    return collected
|
|
|
|
|
# ---- scenarios ------------------------------------------------------------ |
|
|
|
|
|
def scenario_execute(session: Session) -> None:
    """Send a real, typed ExecuteCellsCommand and wait for its run.

    Passes when the run completes and at least one cell-op message was
    received on the stream.
    """
    command = ExecuteCellsCommand(
        cell_ids=[CellId_t("exec-cell")], codes=["x = 42"]
    )
    session.qm.control_queue.put(command)
    received = await_run(session, "exec-cell")
    names = [getattr(msg, "name", None) for msg in received]
    assert "cell-op" in names
|
|
|
|
|
def scenario_forward_command(session: Session) -> None:
    """Send an execute-cells command carrying one EXTRA field.

    The run completing shows that the kernel accepted the payload despite
    the field it does not know about.
    """

    # Local stand-in for a newer driver schema: tagged "execute-cells" with
    # the cell_ids/codes fields plus one additional field.
    class ExecuteCellsCommandV2(
        msgspec.Struct, rename="camel", tag_field="type", tag="execute-cells"
    ):
        cell_ids: list[str]
        codes: list[str]
        new_field_added_later: str = "future"

    target = "forward-cmd-cell"
    command = ExecuteCellsCommandV2(cell_ids=[target], codes=["y = 7"])
    session.qm.control_queue.put(command)  # type: ignore[arg-type]
    await_run(session, target)
|
|
|
|
|
def scenario_broadcast(session: Session) -> dict[str, Any]:
    """Run one cell that broadcasts two Notification variants up the stream.

      (a) a registered tag + EXTRA field -> host must decode cleanly
          (forward-compat on a known op).
      (b) a brand-new tag                -> host decoder rejects it
          (forward-compat HAZARD -- the real finding).

    Both live in a single cell to avoid marimo's MultipleDefinitionError
    from re-importing `Notification` / `ClassVar` across cells.

    Case (a) is a hard assertion. Case (b) is only reported: the returned
    dict says whether the unknown tag's bytes reached the host during this
    scenario and how many host-side decode errors it added.
    """
    cell_id = "broadcast-cell"
    v2_target = "v2-broadcast-target"

    # Snapshot the session counters so the report covers this scenario only.
    bytes_seen_before = session.unknown_tag_bytes_seen
    errors_before = session.host_decode_errors

    code = f"""\
from typing import ClassVar
from marimo._messaging.notification import Notification
from marimo._messaging.notification_utils import broadcast_notification

class CellNotificationV2(Notification, tag="cell-op"):
    name: ClassVar[str] = "cell-op"
    cell_id: str
    new_field_added_later: str = "future"

class StressFoo(Notification, tag={UNKNOWN_TAG!r}):
    name: ClassVar[str] = {UNKNOWN_TAG!r}
    payload: str = "hi"

broadcast_notification(CellNotificationV2(cell_id={v2_target!r}))
broadcast_notification(StressFoo())
"""
    command = ExecuteCellsCommand(cell_ids=[CellId_t(cell_id)], codes=[code])
    session.qm.control_queue.put(command)
    batch = await_run(session, cell_id)

    # The V2 notification is recognisable by its own, distinct cell_id.
    saw_v2 = False
    for msg in batch:
        if (
            getattr(msg, "name", None) == "cell-op"
            and getattr(msg, "cell_id", None) == v2_target
        ):
            saw_v2 = True
            break
    assert saw_v2, (
        f"V2 cell-op (extra field on known tag) missing from batch: "
        f"names={[getattr(n, 'name', '?') for n in batch]}"
    )

    arrived = session.unknown_tag_bytes_seen and not bytes_seen_before
    new_errors = session.host_decode_errors - errors_before
    return {
        "unknown_tag_bytes_arrived": arrived,
        "unknown_tag_decode_errors": new_errors,
    }
|
|
|
|
|
# ---- driver --------------------------------------------------------------- |
|
|
|
|
|
# Every cell id used by the warmup probe and the scenarios; run_for_python()
# passes them to spawn(), which registers a default CellConfig for each.
CELL_IDS = ["warmup", "exec-cell", "forward-cmd-cell", "broadcast-cell"]
|
|
|
|
|
def run_for_python(py: str, wheel: pathlib.Path) -> list[Result]:
    """Spawn a kernel under Python ``py`` and run every scenario against it.

    Args:
        py: Python version to launch the kernel with.
        wheel: marimo wheel to install into the isolated environment.

    Returns:
        One ``Result`` per step that ran. A spawn failure or an unavailable
        interpreter yields a single "spawn" result. A failed warmup is
        recorded and the remaining scenarios are skipped, because the kernel
        loop is not live.
    """
    print(f"\n=== python {py} ===")
    cell_ids = [CellId_t(cid) for cid in CELL_IDS]
    try:
        session = spawn(py, wheel, cell_ids)
    except Exception as e:
        print(f" [FAIL] spawn: {e}")
        return [Result(py, "spawn", False, str(e))]
    if session is None:
        print(f" [skip] {py} not available")
        return [Result(py, "spawn", True, "interpreter unavailable")]

    results: list[Result] = []

    def run(name: str, fn: Callable[[Session], Any]) -> Any:
        """Run one step, record PASS/FAIL, and return its value (or None)."""
        try:
            out = fn(session)
        except Exception as e:
            print(f" [FAIL] {name}: {e}")
            results.append(Result(py, name, False, str(e)))
            return None
        print(f" [PASS] {name}")
        results.append(Result(py, name, True))
        return out

    def warmup(s: Session) -> None:
        """Readiness probe: round-trip a trivial exec through the kernel."""
        s.qm.control_queue.put(
            ExecuteCellsCommand(cell_ids=[CellId_t("warmup")], codes=["1"])
        )
        await_run(s, "warmup")

    try:
        # Event-driven readiness probe: a round-tripped exec proves the
        # runtime loop is live. No fixed sleep needed after KERNEL_READY.
        # It goes through run() so a timeout becomes a recorded failure
        # instead of aborting the remaining Python versions and the report.
        run("warmup", warmup)
        if not results[-1].passed:
            return results

        run("execute", scenario_execute)
        run("forward_command", scenario_forward_command)
        info = run("broadcast", scenario_broadcast)
        if info is not None:
            detail = (
                f"unknown_tag bytes_arrived="
                f"{info['unknown_tag_bytes_arrived']}, "
                f"decode_errors={info['unknown_tag_decode_errors']}"
            )
            print(f" [INFO] {detail}")
            results.append(Result(py, "unknown_tag_hazard", True, detail))
    finally:
        session.close()
    return results
|
|
|
|
|
def print_report(results: list[Result]) -> bool:
    """Print a per-Python-version summary of all results.

    Each version gets a PASS/FAIL line with its pass count, followed by one
    line per failed result and one "note" line per passing result that
    carries a detail string.

    Returns:
        True when no result failed, False otherwise.
    """
    print("\n=== Summary ===")

    # Group by version, keeping first-seen order of versions and results.
    grouped: dict[str, list[Result]] = {}
    for res in results:
        if res.python_version not in grouped:
            grouped[res.python_version] = []
        grouped[res.python_version].append(res)

    everything_passed = True
    for version, group in grouped.items():
        passed_count = sum(res.passed for res in group)
        verdict = "FAIL" if any(not res.passed for res in group) else "PASS"
        print(f" {version}: {verdict} ({passed_count}/{len(group)})")
        for res in group:
            if res.passed:
                if res.detail:
                    print(f" note {res.scenario}: {res.detail}")
                continue
            everything_passed = False
            print(f" FAIL {res.scenario}: {res.detail}")
    return everything_passed
|
|
|
|
|
def main() -> int:
    """Parse CLI flags, run the matrix, and return the process exit code.

    Builds a wheel unless ``--wheel`` is given, runs every requested Python
    version (default: DEFAULT_PYTHON_VERSIONS), and prints the summary.

    Returns:
        0 when every result passed, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--python",
        action="append",
        default=None,
        help="Python version (repeatable). Default: 3.10-3.14.",
    )
    parser.add_argument(
        "--wheel",
        type=pathlib.Path,
        default=None,
        help="Use an existing wheel instead of building.",
    )
    args = parser.parse_args()

    wheel = args.wheel if args.wheel else build_wheel()
    versions = args.python if args.python else DEFAULT_PYTHON_VERSIONS

    all_results: list[Result] = [
        result for py in versions for result in run_for_python(py, wheel)
    ]

    if print_report(all_results):
        return 0
    return 1


if __name__ == "__main__":
    raise SystemExit(main())