Skip to content

Instantly share code, notes, and snippets.

@egil
Last active April 1, 2024 14:11
Show Gist options
  • Save egil/b22c4db106c880c41cecc2cae64d3ca2 to your computer and use it in GitHub Desktop.
Save egil/b22c4db106c880c41cecc2cae64d3ca2 to your computer and use it in GitHub Desktop.
A helper class, DatabaseMemoryCache, that uses Microsoft.Extensions.Caching.Memory.IMemoryCache to cache async tasks and their result data, e.g. EF Core queries. The CustomerRepository below shows an example of how to use it.
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Memory;
namespace Solution.Data;
/// <summary>
/// Example repository that serves per-company customer lists through a <see cref="DatabaseMemoryCache"/>
/// and invalidates the affected company's cached lists after successful writes.
/// </summary>
public sealed class CustomerRepository : IDisposable
{
    // NOTE(review): the base `DbContext` type exposes no `Customers` set; this presumably should be the
    // application's concrete context type — confirm.
    private readonly IDbContextFactory<DbContext> dbFactory;
    private readonly DatabaseMemoryCache databaseCache;

    /// <summary>
    /// Creates the repository with a private <see cref="DatabaseMemoryCache"/> whose entries use
    /// <see cref="CacheItemPriority.High"/>.
    /// </summary>
    /// <param name="dbFactory">Factory used to create a short-lived context per operation.</param>
    /// <param name="memoryCache">Memory cache that stores the cached query results.</param>
    /// <param name="logger">Logger handed to the cache for its diagnostics.</param>
    public CustomerRepository(
        IDbContextFactory<DbContext> dbFactory,
        IMemoryCache memoryCache,
        // NOTE(review): the category type looks like a copy-paste slip; `ILogger<CustomerRepository>`
        // is probably intended. Left unchanged to keep the constructor signature stable.
        ILogger<CompanyRepository> logger)
    {
        this.dbFactory = dbFactory;
        databaseCache = new DatabaseMemoryCache(
            static options => options.SetPriority(CacheItemPriority.High),
            memoryCache,
            logger);
    }

    /// <summary>Disposes the underlying <see cref="DatabaseMemoryCache"/>, evicting its entries.</summary>
    public void Dispose()
    {
        databaseCache.Dispose();
    }

    /// <summary>
    /// Inserts or updates <paramref name="customer"/> and, when the save changed anything, invalidates
    /// the cached customer lists of the customer's company.
    /// </summary>
    /// <returns>The same <paramref name="customer"/> instance that was passed in.</returns>
    public async Task<Customer> CreateOrUpdateAsync(Customer customer)
    {
        using var dbContext = await dbFactory.CreateDbContextAsync();
        dbContext.Customers.Update(customer);
        var changes = await dbContext.SaveChangesAsync();
        if (changes > 0)
        {
            // NOTE(review): if an update can move a customer to another company, the previous
            // company's cached lists are not invalidated here — confirm whether that can happen.
            InvalidateCompany(customer.CompanyId);
        }

        return customer;
    }

    /// <summary>
    /// Deletes the customer with <paramref name="customerId"/>, if it exists, and invalidates the cached
    /// customer lists of its company when the delete was saved.
    /// </summary>
    public async Task DeleteAsync(int customerId)
    {
        using var dbContext = await dbFactory.CreateDbContextAsync();
        var customer = await dbContext.Customers.FirstOrDefaultAsync(x => x.Id == customerId);
        if (customer is null)
        {
            return;
        }

        dbContext.Remove(customer);
        var result = await dbContext.SaveChangesAsync();
        if (result > 0)
        {
            InvalidateCompany(customer.CompanyId);
        }
    }

    /// <summary>
    /// Returns the customers of <paramref name="companyId"/> ordered by name, served from the cache
    /// when available.
    /// </summary>
    /// <param name="companyId">Company id; it is upper-cased (invariant culture) before use.</param>
    /// <param name="includeArchived">When <see langword="false"/>, archived customers are excluded.</param>
    /// <param name="cancellationToken">Token for waiting on the cache and for the database load.</param>
    public Task<IReadOnlyList<Customer>> GetByCompanyAsync(string companyId, bool includeArchived = false, CancellationToken cancellationToken = default)
    {
        companyId = companyId.ToUpperInvariant();
        return databaseCache.GetAsync(
            new Query(companyId, includeArchived),
            LoadCustomers,
            cancellationToken);
    }

    // Evicts every cached Query for the given company. GetByCompanyAsync stores keys under the
    // upper-cased id, so the comparison is case-insensitive to also match entities whose CompanyId
    // is stored in a different case.
    private void InvalidateCompany(string? companyId)
    {
        databaseCache.InvalidateCache<Query>(
            query => string.Equals(query.CompanyId, companyId, StringComparison.OrdinalIgnoreCase));
    }

    // Loads the customers described by `query` from the database without change tracking.
    private async Task<IReadOnlyList<Customer>> LoadCustomers(Query query, CancellationToken cancellationToken)
    {
        using var dbContext = await dbFactory.CreateDbContextAsync(cancellationToken);
        var result = query.IncludeArchived
            ? dbContext.Customers
            : dbContext.Customers.Where(x => !x.Archived);
        return await result
            .Where(x => x.CompanyId == query.CompanyId)
            .OrderBy(x => x.Name)
            .AsNoTracking()
            .ToListAsync(cancellationToken);
    }

    // Cache key for GetByCompanyAsync. Record value equality makes equal queries share one cache entry.
    private sealed record class Query(string CompanyId, bool IncludeArchived)
    {
        // Presumably tags this repository's keys (e.g. in the key's ToString() used by cache logging).
        public string Kind { get; } = nameof(CustomerRepository);
    }
}
using System.Collections.Concurrent;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Primitives;
namespace Microsoft.Extensions.Caching.Memory;
/// <summary>
/// Shares and caches the <see cref="Task{TResult}"/> produced by an (expensive) asynchronous load,
/// e.g. an EF Core query, in an <see cref="IMemoryCache"/>. Callers requesting the same key while the
/// entry is cached receive the same task, and therefore the same result.
/// </summary>
/// <remarks>
/// Each cached entry gets its own <see cref="CancellationTokenSource"/>, registered as the entry's
/// expiration token. <see cref="InvalidateCache{TCacheKey}(Predicate{TCacheKey})"/> cancels the sources
/// of matching keys, which evicts those entries. Cache entry options (priority, expiration, ...) can be
/// customized through the <c>configureCacheEntry</c> constructor argument.
/// </remarks>
public sealed partial class DatabaseMemoryCache : IDisposable
{
    // Serializes the "look up or start a load" step so concurrent callers for one key share a single task.
    private readonly SemaphoreSlim cacheLock = new(1, 1);
    private readonly ILogger logger;
    private readonly Action<MemoryCacheEntryOptions> configureCacheEntry;
    private readonly IMemoryCache memoryCache;

    // Invalidation source per cached key. Whoever successfully removes a source from this dictionary
    // owns it and cancels/disposes it (CancelAndDispose), so every source is disposed exactly once.
    private readonly ConcurrentDictionary<object, CancellationTokenSource> cacheInvalidationTokens = new();
    private bool disposed;

    /// <summary>
    /// Creates a cache whose entries use <see cref="CacheItemPriority.Normal"/> and a five-minute
    /// sliding expiration.
    /// </summary>
    public DatabaseMemoryCache(
        IMemoryCache cache,
        ILogger logger)
        : this(
            static (options) => options.SetPriority(CacheItemPriority.Normal).SetSlidingExpiration(TimeSpan.FromMinutes(5)),
            cache,
            logger)
    {
    }

    /// <summary>Creates a cache whose entry options are customized by <paramref name="configureCacheEntry"/>.</summary>
    /// <param name="configureCacheEntry">Applied to the options of every entry this instance stores.</param>
    /// <param name="cache">The memory cache that holds the entries.</param>
    /// <param name="logger">Logger for cache hits, evictions and failed loads.</param>
    public DatabaseMemoryCache(
        Action<MemoryCacheEntryOptions> configureCacheEntry,
        IMemoryCache cache,
        ILogger logger)
    {
        this.logger = logger;
        this.configureCacheEntry = configureCacheEntry;
        memoryCache = cache;
    }

    /// <summary>Cancels every entry created by this instance (evicting it) and releases the lock.</summary>
    public void Dispose()
    {
        if (disposed)
        {
            return;
        }

        disposed = true;
        foreach (var cacheKey in cacheInvalidationTokens.Keys)
        {
            EvictAndDispose(cacheKey);
        }

        cacheLock.Dispose();
    }

    /// <summary>
    /// Returns the result for <paramref name="cacheKey"/>, reusing a cached task when one exists and has
    /// neither faulted nor been cancelled; otherwise starts <paramref name="createResultTask"/> and caches
    /// its task.
    /// </summary>
    /// <remarks>
    /// A newly started load receives the <paramref name="cancellationToken"/> of the caller that starts it.
    /// Because the task is shared, cancellation honored by the load affects every caller awaiting it.
    /// </remarks>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="OperationCanceledException">
    /// <paramref name="cancellationToken"/> was cancelled while waiting, or the shared load was cancelled.
    /// </exception>
    public async Task<TResult> GetAsync<TCacheKey, TResult>(
        TCacheKey cacheKey,
        Func<TCacheKey, CancellationToken, Task<TResult>> createResultTask,
        CancellationToken cancellationToken)
        where TCacheKey : notnull
    {
        ObjectDisposedException.ThrowIf(disposed, this);

        Task<TResult> result;
        await cacheLock.WaitAsync(cancellationToken);
        try
        {
            result = GetOrCreateResultTask(cacheKey, createResultTask, cancellationToken);
        }
        finally
        {
            cacheLock.Release();
        }

        return await result;
    }

    /// <summary>
    /// Evicts every entry whose key is a <typeparamref name="TCacheKey"/> matching
    /// <paramref name="invalidatePredicate"/>.
    /// </summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public void InvalidateCache<TCacheKey>(Predicate<TCacheKey> invalidatePredicate)
    {
        ObjectDisposedException.ThrowIf(disposed, this);
        foreach (var item in cacheInvalidationTokens)
        {
            if (item.Key is not TCacheKey cacheKey)
            {
                continue;
            }

            if (invalidatePredicate(cacheKey))
            {
                EvictAndDispose(item.Key);
            }
        }
    }

    // Removes whatever source is currently registered for the key and cancels it, which evicts the entry.
    private void EvictAndDispose(object cacheInvalidationTokenKey)
    {
        if (cacheInvalidationTokens.TryRemove(cacheInvalidationTokenKey, out var cts))
        {
            CancelAndDispose(cts);
        }
    }

    // Removes the key's source only if it is still `expected`. This is used from eviction callbacks, which
    // run asynchronously and must not tear down the source of a newer entry stored under the same key.
    private void EvictAndDispose(object cacheInvalidationTokenKey, CancellationTokenSource expected)
    {
        if (cacheInvalidationTokens.TryRemove(new KeyValuePair<object, CancellationTokenSource>(cacheInvalidationTokenKey, expected)))
        {
            CancelAndDispose(expected);
        }
    }

    // Callers must own `cts`, i.e. have just removed it from cacheInvalidationTokens.
    private static void CancelAndDispose(CancellationTokenSource cts)
    {
        try
        {
            cts.Cancel();
        }
        finally
        {
            cts.Dispose();
        }
    }

    // Must be called while holding cacheLock.
    private Task<TResult> GetOrCreateResultTask<TCacheKey, TResult>(
        TCacheKey cacheKey,
        Func<TCacheKey, CancellationToken, Task<TResult>> createResultTask,
        CancellationToken userCancellationToken)
        where TCacheKey : notnull
    {
        if (memoryCache.Get<CachedLoad<TCacheKey, TResult>>(cacheKey) is { } cached)
        {
            var cachedTask = cached.Result;
            if (!cachedTask.IsFaulted && !cachedTask.IsCanceled)
            {
                LogDataFoundInCache(cacheKey);
                return cachedTask;
            }

            if (cachedTask.IsFaulted)
            {
                LogDataLoadFailed(cachedTask.Exception!);
            }
        }

        // Start the load first: if createResultTask throws synchronously, no invalidation source has
        // been registered yet, so nothing is left behind in cacheInvalidationTokens.
        var load = new CachedLoad<TCacheKey, TResult>(cacheKey, createResultTask(cacheKey, userCancellationToken));
        TryCacheValue(cacheKey, load);
        return load.Result;
    }

    private void TryCacheValue<TCacheKey, TResult>(TCacheKey cacheKey, CachedLoad<TCacheKey, TResult> value)
        where TCacheKey : notnull
    {
        // Retire the previous entry for this key (e.g. a faulted/cancelled load being replaced, or an
        // entry whose eviction callback has not run yet) before registering the new one.
        EvictAndDispose(cacheKey);

        // A load that has already failed is not cached; the next GetAsync call starts a new one.
        if (value.Result.Status is TaskStatus.Canceled or TaskStatus.Faulted)
        {
            return;
        }

        // Deliberately not linked to the caller's token: one caller cancelling later must not evict an
        // entry that other callers share.
        var cts = new CancellationTokenSource();
        cacheInvalidationTokens[cacheKey] = cts;
        try
        {
            memoryCache.Set(cacheKey, value, CreateCacheEntryOptions(cts));
        }
        catch (ObjectDisposedException) when (cts.IsCancellationRequested)
        {
            // A concurrent InvalidateCache/Dispose already cancelled and disposed the source, so the
            // value is considered stale and is simply not cached.
        }
    }

    private MemoryCacheEntryOptions CreateCacheEntryOptions(CancellationTokenSource cts)
    {
        var options = new MemoryCacheEntryOptions()
            .AddExpirationToken(new CancellationChangeToken(cts.Token))
            .RegisterPostEvictionCallback(callback: (key, _, reason, _) =>
            {
                LogDataEvictedFromCache(key, reason);
                EvictAndDispose(key, cts);
            });
        configureCacheEntry(options);
        return options;
    }

    [LoggerMessage(
        Message = "Data load failed unexpectedly.",
        Level = LogLevel.Error)]
    private partial void LogDataLoadFailed(Exception exception);

    [LoggerMessage(
        Message = "Data for `{CacheKey}` found in cache.",
        Level = LogLevel.Debug)]
    private partial void LogDataFoundInCache(object cacheKey);

    [LoggerMessage(
        Message = "Data for `{CacheKey}` evicted from cache. {Reason}.",
        Level = LogLevel.Debug)]
    private partial void LogDataEvictedFromCache(object cacheKey, EvictionReason reason);

    // The value stored in the memory cache: the key plus the shared load task.
    private sealed record class CachedLoad<TCacheKey, TResult>(TCacheKey CacheKey, Task<TResult> Result);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment