Created
March 26, 2026 15:05
-
-
Save mvdbeek/880b1b4929b8b0d5f2f2ee041717c6c8 to your computer and use it in GitHub Desktop.
Test demonstrating N+1 fix for workflow download (GALAXY-MAIN-14JQ)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Demonstrate the N+1 fix for workflow download (GALAXY-MAIN-14JQ). | |
| Creates a workflow with many steps and connections, then measures | |
| the number of SQL statements emitted when accessing input_connections | |
| with and without the selectinload fix. | |
| """ | |
| import logging | |
| import threading | |
| import time | |
| import pytest | |
| from sqlalchemy import event, select | |
| from sqlalchemy.orm import joinedload, selectinload, subqueryload | |
| from galaxy import model as m | |
| from galaxy.model import mapper_registry | |
| from galaxy.model.unittest_utils.model_testing_utils import initialize_model | |
| from galaxy.managers.workflows import _get_stored_workflow | |
@pytest.fixture(scope="module")
def init_model(engine):
    """Module-scoped fixture: run ``initialize_model`` for ``mapper_registry`` on ``engine``.

    NOTE(review): the fixture is not ``autouse`` and the test in this module does
    not request it by name; presumably a conftest ``session`` fixture depends on
    it -- confirm.
    """
    initialize_model(mapper_registry, engine)
class QueryCounter:
    """Record the SQL statements emitted through an engine, per calling thread.

    Constructing an instance attaches a ``before_cursor_execute`` listener to
    ``engine``. Statements are kept in thread-local storage, so each thread only
    sees the statements it executed itself.

    Call :meth:`close` (or use the instance as a context manager) when done so
    the listener does not stay attached to the engine and keep recording.
    """

    _EVENT_NAME = "before_cursor_execute"

    def __init__(self, engine):
        self._local = threading.local()
        # Remember the engine so the listener can be detached again in close().
        self._engine = engine
        event.listen(engine, self._EVENT_NAME, self._callback)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so exceptions raised inside the ``with`` body propagate.
        self.close()

    def close(self):
        """Detach the listener from the engine; calling it again is a no-op."""
        engine, self._engine = self._engine, None
        if engine is not None:
            event.remove(engine, self._EVENT_NAME, self._callback)

    def _callback(self, conn, cursor, statement, parameters, context, executemany):
        """Listener hook: append ``statement`` to the calling thread's list."""
        try:
            recorded = self._local.queries
        except AttributeError:
            # First statement seen on this thread (and no reset() yet).
            recorded = self._local.queries = []
        recorded.append(statement)

    def reset(self):
        """Forget every statement recorded so far on the calling thread."""
        self._local.queries = []

    @property
    def count(self):
        """Number of statements recorded on the calling thread."""
        return len(getattr(self._local, "queries", []))

    @property
    def queries(self):
        """A copy of the statements recorded on the calling thread."""
        return list(getattr(self._local, "queries", []))
# Default workflow size: one data input step plus NUM_STEPS - 1 tool steps.
NUM_STEPS = 20


def _make_connected_tool_step(index, source_step):
    """Return a tool step at ``index`` whose single input is fed by ``source_step``."""
    step = m.WorkflowStep()
    step.type = "tool"
    step.tool_id = f"cat{index}"
    step.order_index = index
    step.position = m.WorkflowStep.DEFAULT_POSITION
    step_input = m.WorkflowStepInput(step)
    step_input.name = "input1"
    conn = m.WorkflowStepConnection()
    conn.output_step = source_step
    conn.output_name = "output"
    conn.input_step_input = step_input
    step_input.connections = [conn]
    step.inputs = [step_input]
    return step


def _build_workflow(session, num_steps=NUM_STEPS):
    """Persist a stored workflow with one input step fanned out to tool steps.

    The workflow has ``num_steps`` steps in total: a ``data_input`` step at
    order index 0 and ``num_steps - 1`` tool steps, each with one input
    connected to the input step's ``"output"``.

    :param session: session the objects are added to and committed on.
    :param num_steps: total number of steps to create (default ``NUM_STEPS``).
    :returns: the committed ``StoredWorkflow``.
    :raises ValueError: if ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    user = m.User(email="test@test.org", password="password")
    session.add(user)

    stored = m.StoredWorkflow()
    stored.user = user
    workflow = m.Workflow()
    workflow.stored_workflow = stored
    stored.latest_workflow = workflow

    input_step = m.WorkflowStep()
    input_step.type = "data_input"
    input_step.order_index = 0
    input_step.position = m.WorkflowStep.DEFAULT_POSITION
    workflow.steps.append(input_step)

    for index in range(1, num_steps):
        workflow.steps.append(_make_connected_tool_step(index, input_step))

    session.add(stored)
    session.commit()
    return stored
def _run_old_query(session, stored_id):
    """Load the stored workflow using the old loader options (no selectinload)."""
    legacy_options = (
        joinedload(m.StoredWorkflow.annotations),
        joinedload(m.StoredWorkflow.tags),
        subqueryload(m.StoredWorkflow.latest_workflow).joinedload(m.Workflow.steps).joinedload("*"),
    )
    query = (
        select(m.StoredWorkflow)
        .where(m.StoredWorkflow.id == stored_id)
        .options(*legacy_options)
        .limit(1)
    )
    rows = session.scalars(query)
    return rows.first()
def _count_load_and_access(counter, load):
    """Run ``load()`` and then read ``input_connections`` on every step, counting SQL.

    :param counter: ``QueryCounter`` attached to the engine under test.
    :param load: zero-argument callable returning the stored workflow.
    :returns: ``(load_count, access_count, elapsed_seconds, load_sql)`` where
        ``load_sql`` lists the statements recorded while ``load()`` ran.
    """
    counter.reset()
    started = time.perf_counter()
    stored = load()
    load_count = counter.count
    load_sql = counter.queries
    assert stored is not None, "stored workflow could not be loaded"
    counter.reset()
    for step in stored.latest_workflow.steps:
        _ = step.input_connections
    access_count = counter.count
    elapsed = time.perf_counter() - started
    return load_count, access_count, elapsed, load_sql


def test_n_plus_one_before_and_after(engine, session):
    """Compare query counts and wall time with/without the selectinload fix."""
    stored = _build_workflow(session)
    stored_id = stored.id
    counter = QueryCounter(engine)
    try:
        # ---- WITHOUT FIX ----
        # Detach everything from the session before each reload.
        session.expunge_all()
        load_old, access_old, time_old, _ = _count_load_and_access(
            counter,
            lambda: _run_old_query(session, stored_id),
        )

        # ---- WITH FIX (uses _get_stored_workflow which has selectinload) ----
        session.expunge_all()
        load_new, access_new, time_new, fixed_sql = _count_load_and_access(
            counter,
            lambda: _get_stored_workflow(session, workflow_uuid=None, workflow_id=stored_id, by_stored_id=True),
        )
    finally:
        # Do not leave the counting listener attached to the shared engine.
        event.remove(engine, "before_cursor_execute", counter._callback)

    # Log the SQL emitted during the fixed load. This happens after timing so
    # that log formatting does not inflate the "with fix" wall time.
    logging.info("--- SQL emitted by _get_stored_workflow (fixed) ---")
    for i, q in enumerate(fixed_sql):
        logging.info(" [%d] %.200s", i + 1, q.strip().replace("\n", " "))

    total_old = load_old + access_old
    total_new = load_new + access_new
    logging.info("")
    logging.info("=== N+1 Query Comparison (%d steps, %d connections) ===", NUM_STEPS, NUM_STEPS - 1)
    logging.info("")
    logging.info("WITHOUT fix: %d load + %d access = %d total (%.4fs)", load_old, access_old, total_old, time_old)
    logging.info("WITH fix: %d load + %d access = %d total (%.4fs)", load_new, access_new, total_new, time_new)
    logging.info("Queries eliminated: %d", total_old - total_new)
    logging.info("")

    assert access_new == 0, f"N+1 not fixed: {access_new} lazy queries still emitted"
    assert access_old > 0, "Old query should have triggered lazy loads"
    assert total_new < total_old, "Fixed query should use fewer total queries"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment