Skip to content

Instantly share code, notes, and snippets.

@benmccallum
Last active February 14, 2022 06:59
Show Gist options
  • Save benmccallum/5262ad72cdfca14b6b24a69ea6b09ba3 to your computer and use it in GitHub Desktop.
Save benmccallum/5262ad72cdfca14b6b24a69ea6b09ba3 to your computer and use it in GitHub Desktop.
Cursor-based paging for connections with MS SQL CTEs

Cursor-based paging in MS SQL Server that supports arbitrary sort orders

Uses SqlKata.

EF Core support:

-- Connection args:
-- first: 10
-- Numbers every (filtered) row in the requested sort order; BookingID is both the cursor and the
-- final tie-breaker, which keeps ROW_NUMBER() deterministic (assuming BookingID is unique).
WITH [cte]
AS
(
    SELECT
        BookingID AS [RowCursor],
        ROW_NUMBER() OVER (ORDER BY ServiceDate DESC, BookingID DESC) AS [RowNumber]
    FROM [Booking]
    -- WHERE {filters}
)
SELECT
    [RowCursor],
    [RowNumber],
    (SELECT COUNT(*) FROM cte) AS TotalCount -- count of all filtered rows, repeated on every row
FROM [cte]
WHERE [RowNumber] <= 10 -- limits to first x
-- client can determine (only meaningful when results is non-empty):
-- totalCount = results.Any() ? results.First().TotalCount : 0;
-- hasPreviousPage = totalCount > 0 && results.First().RowNumber != 1;
-- hasNextPage = totalCount > 0 && results.Last().RowNumber != totalCount;
-- Connection args:
-- first: 10
-- after: 173391 (last cursor of previous query)
WITH [cte]
AS
(
    SELECT
        BookingID AS RowCursor,
        ROW_NUMBER() OVER (ORDER BY ServiceDate DESC, BookingID DESC) AS RowNumber
    FROM [Booking]
    --WHERE {filters}
)
SELECT
    [RowCursor],
    [RowNumber],
    (SELECT COUNT(*) FROM cte) AS TotalCount
FROM [cte]
WHERE RowNumber > ( -- limits to rows after the last row of the previous query
    SELECT RowNumber
    FROM [cte]
    WHERE RowCursor = 173391
)
AND RowNumber <= ( -- limits to the first x rows after the last row of the previous query
    SELECT [RowNumber] + 10
    FROM [cte]
    WHERE RowCursor = 173391
)
-- Note: if no row has RowCursor = 173391 (e.g. it is now filtered out or deleted), both
-- subqueries return NULL, the comparisons are never true, and the page comes back empty.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using AutoGuru.Client.Shared;
using AutoGuru.Client.Shared.Dtos;
using AutoGuru.Client.Shared.Models;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using SqlKata;
using SqlKata.Execution;
namespace AutoGuru.Service.Shared.Querying
{
    /// <summary>
    /// Pages an ordered SqlKata query into a connection DTO (edges, has-previous/has-next page flags
    /// and total count) using cursor-based paging over a ROW_NUMBER() common table expression (CTE).
    /// </summary>
    /// <remarks>
    /// The caller's query is wrapped in a CTE that projects <c>{cursorColumnName} AS RowCursor</c> and
    /// <c>ROW_NUMBER() OVER ({the query's ORDER BY}) AS RowNumber</c>. The requested page is then sliced
    /// by RowNumber relative to the rows whose cursors match <c>After</c>/<c>Before</c>, so only cursors
    /// are read by SQL here; the full DTOs are loaded afterwards by a caller-supplied function.
    /// </remarks>
    public class ConnectionPager
    {
        private const string CTE = "cte";
        private const string RowCursor = "RowCursor";
        private const string RowNumber = "RowNumber";

        private readonly ILogger<ConnectionPager> _logger;

        // NOTE(review): not read anywhere in this class; kept so the DI constructor is unchanged.
        private readonly IConfiguration _configuration;

        public ConnectionPager(ILogger<ConnectionPager> logger, IConfiguration configuration)
        {
            _logger = logger;
            _configuration = configuration;
        }

        /// <summary>
        /// Fetches one page of <paramref name="query"/> as a connection DTO.
        /// </summary>
        /// <typeparam name="TCursor">CLR type of the cursor column's values.</typeparam>
        /// <typeparam name="TNodeDto">DTO type of each node.</typeparam>
        /// <typeparam name="TEdgeDto">Edge DTO type wrapping a node.</typeparam>
        /// <typeparam name="TConnectionDto">Connection DTO type returned.</typeparam>
        /// <param name="pagingArgs">
        /// Connection arguments. <c>First</c> or <c>Last</c> must be set; when both are set <c>First</c> wins.
        /// <c>After</c>/<c>Before</c> are encoded cursors decoded via <c>Cursor.FromCursor</c>.
        /// </param>
        /// <param name="db">Factory used to execute the queries; its compiler renders the ORDER BY and logged SQL.</param>
        /// <param name="query">Base (filtered) query. It must have an ORDER BY, which defines the cursor order.</param>
        /// <param name="cursorColumnName">
        /// Column selected as each row's cursor. It is inserted into the SQL raw, so it must never come from user input.
        /// </param>
        /// <param name="getCursorFunc">Extracts a node's cursor; used to restore page order and build edge cursors.</param>
        /// <param name="getItemsAsyncFunc">Loads the node DTOs for the page's cursors (returned order does not matter).</param>
        /// <returns>The connection DTO for the requested page.</returns>
        /// <exception cref="ArgumentNullException">A required argument is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="query"/> has no ORDER BY, <paramref name="cursorColumnName"/> is blank,
        /// or neither <c>First</c> nor <c>Last</c> is specified.
        /// </exception>
        public async Task<TConnectionDto> GetConnectionAsync<TCursor, TNodeDto, TEdgeDto, TConnectionDto>(
            ConnectionPaginationArguments pagingArgs,
            QueryFactory db,
            Query query,
            string cursorColumnName,
            Func<TNodeDto, TCursor> getCursorFunc,
            Func<List<TCursor>, Task<TNodeDto[]>> getItemsAsyncFunc)
            where TConnectionDto : ConnectionDto<TNodeDto, TEdgeDto>, new()
            where TEdgeDto : EdgeDto<TNodeDto>, new()
        {
            if (db is null)
            {
                throw new ArgumentNullException(nameof(db));
            }
            if (query is null)
            {
                throw new ArgumentNullException(nameof(query));
            }
            if (getCursorFunc is null)
            {
                throw new ArgumentNullException(nameof(getCursorFunc));
            }
            if (getItemsAsyncFunc is null)
            {
                throw new ArgumentNullException(nameof(getItemsAsyncFunc));
            }
            if (string.IsNullOrWhiteSpace(cursorColumnName))
            {
                throw new ArgumentException("A cursor column name must be provided.", nameof(cursorColumnName));
            }
            // Previously this surfaced as an opaque "Nullable object must have a value" from Last.Value.
            if (!pagingArgs.First.HasValue && !pagingArgs.Last.HasValue)
            {
                throw new ArgumentException(
                    $"Either {nameof(pagingArgs.First)} or {nameof(pagingArgs.Last)} must be specified.",
                    nameof(pagingArgs));
            }

            // TotalCount is also what hasPrev/hasNext are derived from, so select it for any of the three.
            var shouldSelectTotalCount = pagingArgs.IsTotalCountRequested
                || pagingArgs.IsHasPrevPageRequested
                || pagingArgs.IsHasNextPageRequested;

            var cteQuery = BuildCteQuery(db, query, cursorColumnName);
            var edgesQuery = BuildEdgesQuery<TCursor>(db, cteQuery, pagingArgs, shouldSelectTotalCount);
            var pageOfEdges = await ExecuteEdgesQueryAsync(db, edgesQuery);

            var totalCount = 0;
            if (shouldSelectTotalCount)
            {
                // TotalCount rides along on each returned row, so an empty page (paging past the end,
                // first: 0, ...) carries none; count the CTE separately instead of reporting 0.
                totalCount = pageOfEdges.Length > 0
                    ? (int)pageOfEdges[0].TotalCount
                    : await CountCteRowsAsync(db, cteQuery);
            }

            // Explicit bools (not dynamic) keep the ConnectionDto.From call below statically bound.
            var hasPrevPage = false;
            var hasNextPage = false;
            if (totalCount > 0 && pageOfEdges.Length > 0)
            {
                hasPrevPage = (long)pageOfEdges[0].RowNumber != 1;
                hasNextPage = (long)pageOfEdges[pageOfEdges.Length - 1].RowNumber != totalCount;
            }

            var cursors = pageOfEdges.Select(e => (TCursor)e.RowCursor).ToList();

            // Load the full DTOs, then restore the page order established by the edges query
            TNodeDto[] edges;
            try
            {
                edges = (await getItemsAsyncFunc(cursors))
                    .OrderBy(i => cursors.IndexOf(getCursorFunc(i)))
                    .ToArray();
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error performing Connection DB query for full DTOs.");
                throw;
            }

            // Return as a connection dto
#pragma warning disable IDE0039 // Use local function, doesn't compile when local function
            Func<TNodeDto, object> getCursorFunc2 = i => getCursorFunc(i);
#pragma warning restore IDE0039 // Use local function

            return ConnectionDto.From<TNodeDto, TEdgeDto, TConnectionDto>(
                edges,
                getCursorFunc2,
                hasPrevPage,
                hasNextPage,
                totalCount
            );
        }

        /// <summary>
        /// Clones <paramref name="query"/> into the CTE body, replacing its SELECT with
        /// <c>{cursorColumnName} AS RowCursor, ROW_NUMBER() OVER ({ORDER BY}) AS RowNumber</c>.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="query"/> has no ORDER BY clause.</exception>
        private static Query BuildCteQuery(QueryFactory db, Query query, string cursorColumnName)
        {
            var cteQuery = query.Clone();

            // The ORDER BY is the cursor-determining expression; it moves into the window function
            var order = db.Compiler.CompileOrders(new SqlResult
            {
                Query = cteQuery
            });
            if (string.IsNullOrWhiteSpace(order))
            {
                throw new ArgumentException($"{nameof(query)} does not have an ORDER BY clause", nameof(query));
            }

            // SQL Server rejects ORDER BY inside a CTE without TOP/OFFSET; ROW_NUMBER() now carries it
            cteQuery.Clauses.RemoveAll(c => c.Component == "order");
            cteQuery.Clauses.RemoveAll(c => c.Component == "select");
            cteQuery.SelectRaw(
                $"{cursorColumnName} AS {RowCursor}, " +
                $"ROW_NUMBER() OVER ({order}) AS {RowNumber}");

            return cteQuery;
        }

        /// <summary>
        /// Builds the query selecting RowCursor, RowNumber (and optionally TotalCount) for the
        /// requested page, slicing the CTE by RowNumber relative to the after/before cursors.
        /// Assumes the caller has verified that First or Last is set.
        /// </summary>
        private static Query BuildEdgesQuery<TCursor>(
            QueryFactory db,
            Query cteQuery,
            ConnectionPaginationArguments pagingArgs,
            bool shouldSelectTotalCount)
        {
            var edgesQuery = db.Query()
                .With(CTE, cteQuery)
                .From(CTE)
                .SelectRaw(
                    $"{RowCursor}, " +
                    $"{RowNumber}" +
                    (shouldSelectTotalCount ? $", (SELECT COUNT(*) FROM {CTE}) AS TotalCount" : ""));

            // Only rows strictly after the row holding the after cursor
            var hasAfterCursor = !string.IsNullOrWhiteSpace(pagingArgs.After);
            var afterCursor = hasAfterCursor ? Cursor.FromCursor<TCursor>(pagingArgs.After) : default;
            if (hasAfterCursor)
            {
                edgesQuery.Where(RowNumber, ">",
                    new Query(CTE)
                        .Select(RowNumber)
                        .Where(RowCursor, afterCursor));
            }

            // Only rows strictly before the row holding the before cursor
            var hasBeforeCursor = !string.IsNullOrWhiteSpace(pagingArgs.Before);
            var beforeCursor = hasBeforeCursor ? Cursor.FromCursor<TCursor>(pagingArgs.Before) : default;
            if (hasBeforeCursor)
            {
                edgesQuery.Where(RowNumber, "<",
                    new Query(CTE)
                        .Select(RowNumber)
                        .Where(RowCursor, beforeCursor));
            }

            // Page sizes are passed as bindings rather than interpolated into the raw SQL
            if (pagingArgs.First.HasValue)
            {
                var first = pagingArgs.First.Value;
                if (hasAfterCursor)
                {
                    // The `first` rows immediately following the after cursor's row
                    edgesQuery.Where(RowNumber, "<=",
                        new Query(CTE)
                            .SelectRaw($"{RowNumber} + ?", first)
                            .Where(RowCursor, afterCursor));
                }
                else
                {
                    // The `first` rows from the start
                    edgesQuery.Where(RowNumber, "<=", first);
                }
            }
            else
            {
                var last = pagingArgs.Last.Value;
                if (hasBeforeCursor)
                {
                    // The `last` rows immediately preceding the before cursor's row
                    edgesQuery.Where(RowNumber, ">=",
                        new Query(CTE)
                            .SelectRaw($"{RowNumber} - ?", last)
                            .Where(RowCursor, beforeCursor));
                }
                else
                {
                    // The `last` rows counted from the end
                    edgesQuery.Where(RowNumber, ">",
                        new Query(CTE)
                            .SelectRaw("COUNT(*) - ?", last));
                }
            }

            return edgesQuery;
        }

        /// <summary>
        /// Executes the edges query, logging its SQL and timing at Information level, and logging
        /// the SQL on failure before rethrowing.
        /// </summary>
        private async Task<dynamic[]> ExecuteEdgesQueryAsync(QueryFactory db, Query edgesQuery)
        {
            dynamic[] pageOfEdges;
            var sw = Stopwatch.StartNew();
            try
            {
                pageOfEdges = (await edgesQuery.GetAsync()).ToArray();
            }
            catch (Exception ex)
            {
                // For debugging, the generated SQL is included in the log
                var sqlResult = db.Compiler.Compile(edgesQuery);
                _logger.LogError(ex, "Error performing Connection DB query for page of edges. Query = {Sql}", sqlResult.Sql);
                throw;
            }
            sw.Stop();

            if (_logger.IsEnabled(LogLevel.Information))
            {
                // Wrapped in a log level check so we're not compiling the query unnecessarily
                _logger.LogInformation("Executed DbCommand (SqlKata) ({elapsed}ms) {commandText}",
                    sw.ElapsedMilliseconds,
                    db.Compiler.Compile(edgesQuery));
            }

            return pageOfEdges;
        }

        /// <summary>
        /// Counts all rows of the CTE; used only when the page is empty and so carries no TotalCount.
        /// </summary>
        private async Task<int> CountCteRowsAsync(QueryFactory db, Query cteQuery)
        {
            // Clone so this query does not share the CTE instance already attached to the edges query
            var countQuery = db.Query()
                .With(CTE, cteQuery.Clone())
                .From(CTE)
                .SelectRaw("COUNT(*)");
            try
            {
                return await countQuery.FirstAsync<int>();
            }
            catch (Exception ex)
            {
                var sqlResult = db.Compiler.Compile(countQuery);
                _logger.LogError(ex, "Error performing Connection DB query for total count. Query = {Sql}", sqlResult.Sql);
                throw;
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment