Created
March 17, 2025 06:11
-
-
Save dfee/66e31eac57d91b8a2f938ef2d4319eb4 to your computer and use it in GitHub Desktop.
An sqlx enum that implements `Executor` over a borrowed `PgConnection` or `PgPool` (and, via `&mut *txn`, a transaction).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use futures_core::{future::BoxFuture, stream::BoxStream}; | |
use sqlx::{ | |
postgres::{PgQueryResult, PgRow, PgStatement, PgTypeInfo}, | |
Describe, Either, Error, Execute, Executor, PgConnection, PgPool, PgTransaction, Postgres, | |
}; | |
#[derive(Debug)] | |
enum XStore<'a> { | |
Connection(&'a mut PgConnection), | |
Pool(&'a PgPool), | |
} | |
impl<'c> Executor<'c> for &'c mut XStore<'c> { | |
type Database = Postgres; | |
fn fetch_many<'e, 'q, E>( | |
self, | |
mut query: E, | |
) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>> | |
where | |
'c: 'e, | |
E: Execute<'q, Self::Database>, | |
'q: 'e, | |
E: 'q, | |
{ | |
match self { | |
XStore::Connection(conn) => conn.fetch_many(query), | |
XStore::Pool(pool) => pool.fetch_many(query), | |
} | |
} | |
fn fetch_optional<'e, 'q, E>( | |
self, | |
mut query: E, | |
) -> BoxFuture<'e, Result<Option<PgRow>, Error>> | |
where | |
'c: 'e, | |
E: Execute<'q, Self::Database>, | |
'q: 'e, | |
E: 'q, | |
{ | |
match self { | |
XStore::Connection(conn) => conn.fetch_optional(query), | |
XStore::Pool(pool) => pool.fetch_optional(query), | |
} | |
} | |
fn prepare_with<'e, 'q: 'e>( | |
self, | |
sql: &'q str, | |
parameters: &'e [PgTypeInfo], | |
) -> BoxFuture<'e, Result<PgStatement<'q>, Error>> | |
where | |
'c: 'e, | |
{ | |
match self { | |
XStore::Connection(conn) => conn.prepare_with(sql, parameters), | |
XStore::Pool(pool) => pool.prepare_with(sql, parameters), | |
} | |
} | |
fn describe<'e, 'q: 'e>( | |
self, | |
sql: &'q str, | |
) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>> | |
where | |
'c: 'e, | |
{ | |
match self { | |
XStore::Connection(conn) => conn.describe(sql), | |
XStore::Pool(pool) => pool.describe(sql), | |
} | |
} | |
} | |
async fn get_value_xstore<'a>( | |
xstore: &'a mut XStore<'a>, | |
value: i64, | |
) -> Result<i64, sqlx::Error> { | |
sqlx::query_scalar::<_, i64>("SELECT $1") | |
.bind(value) | |
.fetch_one(xstore) | |
.await | |
} | |
#[cfg(test)]
mod tests {
    use database::testing::get_pool;

    use super::*;

    /// Value bound to `SELECT $1` in every test; each test expects it echoed back unchanged.
    const ECHO: i64 = 150;

    /// An `XStore` wrapping the pool itself.
    #[tokio::test]
    async fn test_pool() -> anyhow::Result<()> {
        let pg_pool = get_pool().await;
        let mut store = XStore::Pool(&pg_pool);

        let echoed = get_value_xstore(&mut store, ECHO).await?;
        assert_eq!(echoed, ECHO);
        Ok(())
    }

    /// An `XStore` wrapping a single connection checked out of the pool.
    #[tokio::test]
    async fn test_connection() -> anyhow::Result<()> {
        let pg_pool = get_pool().await;
        let mut checked_out = pg_pool.acquire().await?;
        let mut store = XStore::Connection(&mut checked_out);

        let echoed = get_value_xstore(&mut store, ECHO).await?;
        assert_eq!(echoed, ECHO);
        Ok(())
    }

    /// An `XStore` wrapping the connection inside an open transaction, rolled back at the end.
    #[tokio::test]
    async fn test_transaction() -> anyhow::Result<()> {
        let pg_pool = get_pool().await;
        let mut tx = pg_pool.begin().await?;
        let mut store: XStore = XStore::Connection(&mut *tx);

        let echoed = get_value_xstore(&mut store, ECHO).await?;
        assert_eq!(echoed, ECHO);

        tx.rollback().await?;
        Ok(())
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment