Skip to content

Instantly share code, notes, and snippets.

@mvdbeek
Created March 26, 2026 15:05
Show Gist options
  • Select an option

  • Save mvdbeek/880b1b4929b8b0d5f2f2ee041717c6c8 to your computer and use it in GitHub Desktop.

Select an option

Save mvdbeek/880b1b4929b8b0d5f2f2ee041717c6c8 to your computer and use it in GitHub Desktop.
Test demonstrating N+1 fix for workflow download (GALAXY-MAIN-14JQ)
"""Demonstrate the N+1 fix for workflow download (GALAXY-MAIN-14JQ).
Creates a workflow with many steps and connections, then measures
the number of SQL statements emitted when accessing input_connections
with and without the selectinload fix.
"""
import logging
import threading
import time
import pytest
from sqlalchemy import event, select
from sqlalchemy.orm import joinedload, selectinload, subqueryload
from galaxy import model as m
from galaxy.model import mapper_registry
from galaxy.model.unittest_utils.model_testing_utils import initialize_model
from galaxy.managers.workflows import _get_stored_workflow
@pytest.fixture(scope="module")
def init_model(engine):
    """Run ``initialize_model`` for the mapper registry on *engine*.

    Module-scoped, so the initialization happens once per test module.
    The fixture yields no value; it exists only for its side effect.
    """
    initialize_model(mapper_registry, engine)
class QueryCounter:
    """Record the SQL statements an engine is about to execute.

    A ``before_cursor_execute`` listener is attached to *engine* on
    construction. Statements are kept in a ``threading.local`` so each
    thread sees only the statements recorded on that thread.

    The listener stays attached to the engine until :meth:`close` is
    called. Use the instance as a context manager to detach it
    automatically::

        with QueryCounter(engine) as counter:
            ...
            assert counter.count == 0
    """

    _EVENT_NAME = "before_cursor_execute"

    def __init__(self, engine):
        """Attach the recording listener to *engine*."""
        self._engine = engine
        self._local = threading.local()
        event.listen(engine, self._EVENT_NAME, self._callback)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None so exceptions raised inside the ``with`` body propagate.
        self.close()

    def close(self):
        """Detach the listener from the engine; safe to call repeatedly."""
        # ``contains`` guards against a listener already removed elsewhere,
        # since ``event.remove`` raises for an unregistered listener.
        if event.contains(self._engine, self._EVENT_NAME, self._callback):
            event.remove(self._engine, self._EVENT_NAME, self._callback)

    def _callback(self, conn, cursor, statement, parameters, context, executemany):
        """Event hook: append *statement* to the current thread's list."""
        # The list is created lazily because a thread that never called
        # ``reset`` has no ``queries`` attribute on the thread-local yet.
        self._local.queries = getattr(self._local, "queries", [])
        self._local.queries.append(statement)

    def reset(self):
        """Discard the statements recorded so far on the current thread."""
        self._local.queries = []

    @property
    def count(self):
        """Number of statements recorded on the current thread."""
        return len(getattr(self._local, "queries", []))

    @property
    def queries(self):
        """A copy of the statements recorded on the current thread."""
        return list(getattr(self._local, "queries", []))
# Total number of workflow steps created by _build_workflow
# (one data input step plus NUM_STEPS - 1 tool steps).
NUM_STEPS = 20


def _build_workflow(session):
    """Persist a stored workflow with a fan-out of connections.

    One ``data_input`` step is created at order index 0. Each of the
    remaining ``NUM_STEPS - 1`` steps is a ``tool`` step with a single
    input named ``input1`` that is connected to the data input step's
    ``output``. Everything is added to *session* and committed.

    Returns the committed ``StoredWorkflow``.
    """

    def _new_step(step_type, order_index):
        # Attributes shared by the input step and every tool step.
        node = m.WorkflowStep()
        node.type = step_type
        node.order_index = order_index
        node.position = m.WorkflowStep.DEFAULT_POSITION
        return node

    owner = m.User(email="test@test.org", password="password")
    session.add(owner)

    stored_workflow = m.StoredWorkflow()
    stored_workflow.user = owner

    latest = m.Workflow()
    latest.stored_workflow = stored_workflow
    stored_workflow.latest_workflow = latest

    source_step = _new_step("data_input", 0)
    latest.steps.append(source_step)

    for index in range(1, NUM_STEPS):
        tool_step = _new_step("tool", index)
        tool_step.tool_id = f"cat{index}"

        tool_input = m.WorkflowStepInput(tool_step)
        tool_input.name = "input1"

        # Wire the data input step's output into this tool step's input.
        link = m.WorkflowStepConnection()
        link.output_step = source_step
        link.output_name = "output"
        link.input_step_input = tool_input

        tool_input.connections = [link]
        tool_step.inputs = [tool_input]
        latest.steps.append(tool_step)

    session.add(stored_workflow)
    session.commit()
    return stored_workflow
def _run_old_query(session, stored_id):
    """Load a stored workflow the pre-fix way, i.e. with no selectinload.

    Annotations and tags are joined-loaded, the latest workflow is
    subquery-loaded, and its steps are joined-loaded with a wildcard
    joined load beneath them. Returns the first matching
    ``StoredWorkflow`` or ``None`` when *stored_id* matches nothing.
    """
    eager_options = (
        joinedload(m.StoredWorkflow.annotations),
        joinedload(m.StoredWorkflow.tags),
        subqueryload(m.StoredWorkflow.latest_workflow).joinedload(m.Workflow.steps).joinedload("*"),
    )
    query = (
        select(m.StoredWorkflow)
        .where(m.StoredWorkflow.id == stored_id)
        .options(*eager_options)
        .limit(1)
    )
    return session.scalars(query).first()
def _measure_load_and_access(counter, load):
    """Time one load plus a walk over every step's ``input_connections``.

    *load* is a zero-argument callable returning the stored workflow.
    Statements emitted while loading and while walking the steps are
    counted separately through *counter*.

    Returns ``(load_count, access_count, elapsed_seconds, load_statements)``.
    """
    counter.reset()
    started = time.perf_counter()
    stored = load()
    load_statements = counter.queries
    assert stored is not None, "Stored workflow could not be loaded"
    counter.reset()
    for step in stored.latest_workflow.steps:
        _ = step.input_connections  # triggers a lazy load when not eagerly loaded
    access_count = counter.count
    elapsed = time.perf_counter() - started
    return len(load_statements), access_count, elapsed, load_statements


def test_n_plus_one_before_and_after(engine, session):
    """Compare query counts and wall time with/without the selectinload fix.

    The same workflow is loaded twice from an emptied session: once with
    the old eager-loading options and once through ``_get_stored_workflow``.
    The fixed path must emit no statements when ``input_connections`` is
    accessed and must use fewer statements overall.
    """
    stored = _build_workflow(session)
    stored_id = stored.id
    counter = QueryCounter(engine)
    try:
        # ---- WITHOUT FIX ----
        session.expunge_all()  # force a real reload instead of identity-map hits
        load_old, access_old, time_old, _ = _measure_load_and_access(
            counter,
            lambda: _run_old_query(session, stored_id),
        )
        # ---- WITH FIX (uses _get_stored_workflow which has selectinload) ----
        session.expunge_all()
        load_new, access_new, time_new, sql_new = _measure_load_and_access(
            counter,
            lambda: _get_stored_workflow(session, workflow_uuid=None, workflow_id=stored_id, by_stored_id=True),
        )
    finally:
        # Detach the listener so it does not keep recording on the engine
        # after this test has finished.
        event.remove(engine, "before_cursor_execute", counter._callback)

    # The SQL is logged only after timing so that logging overhead does not
    # count against the fixed path's wall time.
    logging.info("--- SQL emitted by _get_stored_workflow (fixed) ---")
    for i, q in enumerate(sql_new):
        logging.info(" [%d] %.200s", i + 1, q.strip().replace("\n", " "))

    total_old = load_old + access_old
    total_new = load_new + access_new
    logging.info("")
    logging.info("=== N+1 Query Comparison (%d steps, %d connections) ===", NUM_STEPS, NUM_STEPS - 1)
    logging.info("")
    logging.info("WITHOUT fix: %d load + %d access = %d total (%.4fs)", load_old, access_old, total_old, time_old)
    logging.info("WITH fix: %d load + %d access = %d total (%.4fs)", load_new, access_new, total_new, time_new)
    logging.info("Queries eliminated: %d", total_old - total_new)
    logging.info("")
    assert access_new == 0, f"N+1 not fixed: {access_new} lazy queries still emitted"
    assert access_old > 0, "Old query should have triggered lazy loads"
    assert total_new < total_old, "Fixed query should use fewer total queries"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment