Created
April 12, 2026 08:02
-
-
Save emilpriver/6ce4e90ab2fd70bc969ca77b278fb818 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use sqlx::Row; | |
| use sqlx::sqlite::SqlitePool; | |
| use std::future::Future; | |
| use suitcase::{Case, HookFns, RunConfig, cases_fn, run}; | |
| use tokio::runtime::{Handle, Runtime}; | |
| struct DbSuite { | |
| handle: Handle, | |
| pool: Option<SqlitePool>, | |
| before_each_calls: u32, | |
| } | |
| impl DbSuite { | |
| fn new(handle: Handle) -> Self { | |
| Self { | |
| handle, | |
| pool: None, | |
| before_each_calls: 0, | |
| } | |
| } | |
| fn pool(&self) -> &SqlitePool { | |
| self.pool | |
| .as_ref() | |
| .expect("pool must be initialized by setup_suite before hooks or cases run") | |
| } | |
| fn block_on<F>(&self, fut: F) -> F::Output | |
| where | |
| F: Future + Send, | |
| F::Output: Send, | |
| { | |
| self.handle.block_on(fut) | |
| } | |
| async fn count_rows(pool: &SqlitePool, table: &str) -> i64 { | |
| let q = format!("SELECT COUNT(*) AS c FROM {table}"); | |
| let row = sqlx::query(&q).fetch_one(pool).await.unwrap(); | |
| row.get::<i64, _>("c") | |
| } | |
| } | |
| fn test_insert_alice(s: &mut DbSuite) { | |
| s.block_on(async { | |
| sqlx::query("INSERT INTO users (name, version) VALUES (?, ?)") | |
| .bind("alice") | |
| .bind(1i64) | |
| .execute(s.pool()) | |
| .await | |
| .unwrap(); | |
| }); | |
| } | |
| fn test_insert_bob(s: &mut DbSuite) { | |
| s.block_on(async { | |
| sqlx::query("INSERT INTO users (name, version) VALUES (?, ?)") | |
| .bind("bob") | |
| .bind(1i64) | |
| .execute(s.pool()) | |
| .await | |
| .unwrap(); | |
| let n = DbSuite::count_rows(s.pool(), "users").await; | |
| assert_eq!(n, 2, "alice then bob"); | |
| }); | |
| } | |
| fn test_bump_alice_version(s: &mut DbSuite) { | |
| s.block_on(async { | |
| sqlx::query("UPDATE users SET version = ? WHERE name = ?") | |
| .bind(2i64) | |
| .bind("alice") | |
| .execute(s.pool()) | |
| .await | |
| .unwrap(); | |
| let row = sqlx::query("SELECT version FROM users WHERE name = ?") | |
| .bind("alice") | |
| .fetch_one(s.pool()) | |
| .await | |
| .unwrap(); | |
| assert_eq!(row.get::<i64, _>("version"), 2); | |
| }); | |
| } | |
| fn test_assert_final_counts(s: &mut DbSuite) { | |
| s.block_on(async { | |
| assert_eq!(DbSuite::count_rows(s.pool(), "users").await, 2); | |
| let logs = DbSuite::count_rows(s.pool(), "op_log").await; | |
| assert!(logs >= 1, "at least setup + hooks + cases"); | |
| }); | |
| } | |
| fn db_setup_suite(s: &mut DbSuite) { | |
| s.pool = Some(s.block_on(async { | |
| let pool = SqlitePool::connect("sqlite::memory:") | |
| .await | |
| .expect("sqlite connect"); | |
| sqlx::migrate!("tests/sqlx_sqlite_migrations") | |
| .run(&pool) | |
| .await | |
| .expect("apply migrations"); | |
| sqlx::query("INSERT INTO op_log (message) VALUES ('setup_suite')") | |
| .execute(&pool) | |
| .await | |
| .unwrap(); | |
| pool | |
| })); | |
| } | |
| fn db_teardown_suite(s: &mut DbSuite) { | |
| s.block_on(async { | |
| sqlx::query("INSERT INTO op_log (message) VALUES ('teardown_suite')") | |
| .execute(s.pool()) | |
| .await | |
| .unwrap(); | |
| }); | |
| } | |
| fn db_before_each(s: &mut DbSuite) { | |
| s.before_each_calls += 1; | |
| let n = s.before_each_calls; | |
| s.block_on(async { | |
| sqlx::query("INSERT INTO op_log (message) VALUES (?)") | |
| .bind(format!("before_each #{n}")) | |
| .execute(s.pool()) | |
| .await | |
| .unwrap(); | |
| }); | |
| } | |
| fn db_after_each(s: &mut DbSuite) { | |
| s.block_on(async { | |
| sqlx::query("INSERT INTO op_log (message) VALUES ('after_each')") | |
| .execute(s.pool()) | |
| .await | |
| .unwrap(); | |
| }); | |
| } | |
| static DB_SUITE_HOOKS: HookFns<DbSuite> = HookFns { | |
| setup_suite: Some(db_setup_suite), | |
| teardown_suite: Some(db_teardown_suite), | |
| before_each: Some(db_before_each), | |
| after_each: Some(db_after_each), | |
| }; | |
| static DB_SUITE_CASES: &[Case<DbSuite>] = cases_fn![ | |
| DbSuite => | |
| test_insert_alice => test_insert_alice, | |
| test_insert_bob => test_insert_bob, | |
| test_bump_alice_version => test_bump_alice_version, | |
| test_assert_final_counts => test_assert_final_counts, | |
| ]; | |
/// Runs the whole suite against one in-memory database, then checks the
/// final table contents from outside the suite.
///
/// The `op_log` lower bound of 10 matches one `setup_suite` row, one
/// `teardown_suite` row, and one `before_each` plus one `after_each` row for
/// each of the four cases — assuming the runner invokes every hook exactly
/// once per suite/case (TODO confirm against `suitcase::run`).
///
/// # Errors
///
/// Returns an error if the Tokio runtime cannot be created.
#[test]
fn sqlx_sqlite_suite_mutates_db_between_cases() -> Result<(), Box<dyn std::error::Error>> {
    let rt = Runtime::new()?;
    let mut suite = DbSuite::new(rt.handle().clone());

    run(&mut suite, DB_SUITE_CASES, RunConfig::all(), &DB_SUITE_HOOKS);

    // Gather both counts in one trip through the runtime; `users` is still
    // queried before `op_log`.
    let pool = suite.pool();
    let (users, ops) = rt.block_on(async {
        let users = DbSuite::count_rows(pool, "users").await;
        let ops = DbSuite::count_rows(pool, "op_log").await;
        (users, ops)
    });

    assert_eq!(users, 2);
    assert!(ops >= 10, "expected hook + case log rows, got {ops}");
    println!("done: users={users}, op_log rows={ops}");
    Ok(())
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment