Created July 9, 2022 10:56
-
-
Save rj76/3a0f2635473a54e3f880cd14ea46e84e to your computer and use it in GitHub Desktop.
Diesel pagination
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use diesel::prelude::*; | |
use diesel::pg::Pg; | |
use diesel::query_builder::*; | |
use diesel::query_dsl::methods::LoadQuery; | |
use diesel::sql_types::{BigInt}; | |
const DEFAULT_PAGE_SIZE: i64 = 20; | |
#[derive(QueryId)] | |
pub struct Paginated<T> { | |
query: T, | |
page: i64, | |
page_size: i64, | |
} | |
pub trait Paginate: Sized { | |
fn paginate(self, page: i64) -> Paginated<Self>; | |
} | |
impl<T> Paginate for T { | |
fn paginate(self, page: i64) -> Paginated<Self> { | |
Paginated { | |
query: self, | |
page_size: DEFAULT_PAGE_SIZE, | |
page, | |
} | |
} | |
} | |
impl<T> QueryFragment<Pg> for Paginated<T> | |
where | |
T: QueryFragment<Pg>, | |
{ | |
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { | |
out.push_sql("SELECT *, COUNT(*) OVER () FROM ("); | |
self.query.walk_ast(out.reborrow())?; | |
out.push_sql(") t LIMIT "); | |
out.push_bind_param::<BigInt, _>(&self.page_size)?; | |
out.push_sql(" OFFSET "); | |
let offset = (self.page - 1) * self.page_size; | |
out.push_bind_param::<BigInt, _>(&offset)?; | |
Ok(()) | |
} | |
} | |
impl<T: Query> Query for Paginated<T> { | |
type SqlType = (T::SqlType, BigInt); | |
} | |
impl<T> RunQueryDsl<PgConnection> for Paginated<T> {} | |
impl<T> Paginated<T> { | |
pub fn page_size(self, page_size: i64) -> Self { | |
Paginated { page_size, ..self } | |
} | |
pub fn load_and_count_pages<'a, U>(self, conn: &mut PgConnection) -> QueryResult<(Vec<U>, i64, i64)> | |
where | |
Self: LoadQuery<'a, PgConnection, (U, i64)>, | |
{ | |
let page_size = self.page_size; | |
let results = self.load::<(U, i64)>(conn)?; | |
let total = results.get(0).map(|x| x.1).unwrap_or(0); | |
let records = results.into_iter().map(|x| x.0).collect(); | |
let total_pages = (total as f64 / page_size as f64).ceil() as i64; | |
Ok((records, total, total_pages)) | |
} | |
} | |
pub trait LoadPaginated<'a, U>: Query + QueryId + QueryFragment<Pg> + LoadQuery<'a, PgConnection, U> { | |
fn load_with_pagination(self, conn: &mut PgConnection, page: Option<i64>, page_size: Option<i64>) -> QueryResult<(Vec<U>, i64, i64)>; | |
} | |
impl<'a, T, U> LoadPaginated<'a, U> for T | |
where | |
Self: Query + QueryId + QueryFragment<Pg> + LoadQuery<'a, PgConnection, U>, | |
U: Queryable<Self::SqlType, Pg>, | |
// IDE suggested this U: Queryable<Self::SqlType, Pg>, (<T as diesel::query_builder::Query>::SqlType, BigInt): load_dsl::private::CompatibleType<(U, i64), Pg> | |
{ | |
fn load_with_pagination(self, conn: &mut PgConnection, page: Option<i64>, page_size: Option<i64>) -> QueryResult<(Vec<U>, i64, i64)> { | |
let (records, total_pages, total) = match page { | |
Some(page) => { | |
let mut query = self.paginate(page); | |
if let Some(page_size) = page_size { | |
query = query.page_size(page_size); | |
} | |
query.load_and_count_pages::<U>(conn)? | |
}, | |
None => (self.load::<U>(conn)?, 1, 1), | |
}; | |
Ok((records, total_pages, total)) | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.