Created February 8, 2025 23:35
-
-
Save lukeramsden/79cd6a949a1afad1ce79154e00806147 to your computer and use it in GitHub Desktop.
Zapatos cursor-based pagination helper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import * as db from "zapatos/db"; | |
import * as s from "zapatos/schema"; | |
/**
 * Cursor parameters accepted by {@link paginateQuery}.
 *
 * `startingAfter` and `endingBefore` are item IDs resolved through the
 * caller's `referenceItemFunc`; supplying both at once is rejected.
 */
export interface Pagination {
  /** Requested page size; the paginator applies a default and clamps it. */
  limit?: number;
  /** Return items that follow the item with this ID. */
  startingAfter?: null | string;
  /** Return items that precede the item with this ID. */
  endingBefore?: null | string;
}
/**
 * A single page of items returned by {@link paginateQuery}, together with
 * metadata about the wider result set.
 */
export interface PaginatedResult<T> {
  /** The items on this page. */
  items: Array<T>;
  page: {
    /** True when at least one more item exists beyond this page in the direction of travel. */
    hasMore: boolean;
    /** Row count for the caller's filter as reported by the paginator (0 when the cursor could not be resolved). */
    totalCount: number;
  };
}
/**
 * Everything {@link paginateQuery} needs to fetch one page from a table.
 *
 * @typeParam O - Shape of each returned item (the output of `mapResult`).
 * @typeParam T - Zapatos table being paginated.
 * @typeParam L - Zapatos `lateral` option type for the underlying select.
 * @typeParam E - Zapatos `extras` option type for the underlying select.
 */
export interface PaginationQueryOptions<
  O,
  T extends s.Table,
  L extends db.LateralOption<undefined, E>,
  E extends db.ExtrasOption<T>
> {
  /** Table to select from. */
  table: T;

  /**
   * Sort key, given as `[column, field]` pairs: `column` is the table column
   * used for ordering and for the cursor condition, and `field` is the
   * property read from the item returned by `referenceItemFunc`.
   * The `typeId` field value is passed through SQL `typeid_parse(...)`
   * before comparison.
   */
  orderBy: {
    timestamp: [s.ColumnForTable<T>, keyof O];
    typeId: [s.ColumnForTable<T>, keyof O];
  };

  /** Loads the cursor item by ID; resolving to `null` means "not found". */
  referenceItemFunc: (id: string) => Promise<O | null>;

  /** Converts one selected row (including any lateral/extras results) into an output item. */
  mapResult: (
    row: s.JSONSelectableForTable<T> &
      (undefined extends L ? Record<string, never> : L extends db.SQLFragmentMap ? db.LateralResult<L> : never) &
      (undefined extends E ? Record<string, never> : E extends db.SQLFragmentOrColumnMap<T> ? db.ExtrasResult<T, E> : never)
  ) => O;

  /** Caller filter; it must not constrain either of the `orderBy` columns. */
  where?: s.WhereableForTable<T>;
  /** Passed through unchanged to the Zapatos select. */
  lateral?: db.SelectOptionsForTable<T, undefined, L, E, string>["lateral"];
  /** Passed through unchanged to the Zapatos select. */
  extras?: db.SelectOptionsForTable<T, undefined, L, E, string>["extras"];
  /** Page size and cursor. */
  pagination?: Pagination;
}
/**
 * Fetches one page of rows from `options.table` using keyset (cursor)
 * pagination over the `(timestamp, typeId)` column pair, in descending order.
 *
 * - With no cursor, returns the first `limit` rows in descending order.
 * - `startingAfter: id` returns rows whose `(timestamp, typeId)` key is
 *   smaller than that of item `id`, i.e. the rows that follow it.
 * - `endingBefore: id` returns the `limit` rows with a larger key than item
 *   `id` that sit closest to it. They are still presented in descending order.
 *
 * `limit` defaults to 20 and is clamped to [1, 100]. Fractional values are
 * truncated, and `NaN` falls back to the default.
 *
 * @param queryable - Zapatos queryable to run both queries against.
 * @param options - Table, filter, sort columns, cursor resolver and row mapper.
 * @returns The mapped page items. `hasMore` is true when at least one more row
 *   exists past this page in the direction of travel. `totalCount` is the row
 *   count for `options.where` with the cursor not applied; it is 0 when the
 *   cursor item cannot be found.
 * @throws Error if both `startingAfter` and `endingBefore` are set, or if
 *   `options.where` filters on either `orderBy` column.
 */
export async function paginateQuery<O, T extends s.Table, L extends db.LateralOption<undefined, E>, E extends db.ExtrasOption<T>>(
  queryable: db.Queryable,
  options: PaginationQueryOptions<O, T, L, E>,
): Promise<PaginatedResult<O>> {
  // Row type handed to `mapResult`. It is derived from the options interface
  // so the long conditional lateral/extras type is not duplicated here.
  type Row = Parameters<PaginationQueryOptions<O, T, L, E>["mapResult"]>[0];

  const defaultLimit = 20;
  const maxLimit = 100;
  const minLimit = 1;

  const { startingAfter, endingBefore } = options.pagination ?? {};
  const requestedLimit = options.pagination?.limit ?? defaultLimit;
  // Without these guards, NaN survives Math.min/Math.max and would reach SQL
  // as `LIMIT NaN`, and a fractional value would yield a non-integer LIMIT.
  const limit = Number.isNaN(requestedLimit)
    ? defaultLimit
    : Math.max(Math.min(Math.trunc(requestedLimit), maxLimit), minLimit);

  if (startingAfter && endingBefore) {
    throw new Error("Cannot paginate with both startingAfter and endingBefore");
  }

  const [timestampColumn, timestampField] = options.orderBy.timestamp;
  const [typeIdColumn, typeIdField] = options.orderBy.typeId;
  const where = options.where ?? {};

  // Validate arguments before any I/O, so a misconfigured call fails the same
  // way regardless of whether the cursor item exists. The cursor condition is
  // keyed on the timestamp column, so a caller filter on that key would be
  // overwritten by the spread below.
  if (timestampColumn in where) {
    throw new Error(`Cannot filter by column ${timestampColumn} (timestampColumn) when paginating`);
  }
  if (typeIdColumn in where) {
    throw new Error(`Cannot filter by column ${typeIdColumn} (typeIdColumn) when paginating`);
  }

  // Resolve the cursor item. Its sort-key values anchor the keyset condition.
  const cursorId = startingAfter || endingBefore;
  let reference: O | null = null;
  if (cursorId) {
    reference = await options.referenceItemFunc(cursorId);
    if (!reference) {
      // TODO: should we throw an invalid cursor error here instead?
      return { items: [], page: { hasMore: false, totalCount: 0 } };
    }
  }

  // Row-value comparison on (timestamp, typeId). `db.self` stands for the key
  // this condition is attached to (the timestamp column). A non-null
  // `reference` implies exactly one cursor was supplied.
  const paginationFilter: s.WhereableForTable<T> = reference
    ? {
        [timestampColumn]: startingAfter
          ? db.sql`(${db.self}, ${typeIdColumn}) < (${db.param(reference[timestampField])}, typeid_parse(${db.param(reference[typeIdField])}))`
          : db.sql`(${db.self}, ${typeIdColumn}) > (${db.param(reference[timestampField])}, typeid_parse(${db.param(reference[typeIdField])}))`,
      }
    : {};

  // Backward pages are fetched ascending (nearest the cursor first) and
  // flipped afterwards.
  const direction = endingBefore ? "ASC" : "DESC";

  const [totalCount, rows] = await Promise.all([
    db.count(options.table, options.where ?? db.all).run(queryable),
    db.select(
      options.table,
      { ...where, ...paginationFilter },
      {
        // One extra row is fetched purely to detect whether another page exists.
        limit: limit + 1,
        order: [
          { by: timestampColumn, direction, nulls: "LAST" },
          { by: typeIdColumn, direction, nulls: "LAST" },
        ],
        lateral: options.lateral,
        extras: options.extras,
      },
    ).run(queryable),
  ]);

  const hasMore = rows.length > limit;
  // Drop the probe row *before* mapping, so `mapResult` only runs on rows that
  // are actually returned.
  const pageRows = hasMore ? rows.slice(0, limit) : rows;
  // TODO: the row type cast is not nice but it works for now
  const items = pageRows.map((row) => options.mapResult(row as Row));

  // Keep backward pages in the same (descending) order as forward pages.
  if (endingBefore) {
    items.reverse();
  }

  return { items, page: { hasMore, totalCount } };
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment