Skip to content

Instantly share code, notes, and snippets.

@jrgcubano
Forked from StephenCleary/AsyncCache.cs
Created April 15, 2024 10:22
Show Gist options
  • Save jrgcubano/a8bddeb9492656040f2756190627d970 to your computer and use it in GitHub Desktop.
Save jrgcubano/a8bddeb9492656040f2756190627d970 to your computer and use it in GitHub Desktop.
Asynchronous cache
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
using Nito.Logging;
/// <summary>
/// Provides an asynchronous cache with exactly-once creation method semantics and flexible cache entries.
/// </summary>
/// <remarks>
/// <para>Each cache entry stores a <see cref="TaskCompletionSource{TResult}"/>; callers always receive its <see cref="TaskCompletionSource{TResult}.Task"/> (a "future").</para>
/// <para>Every removal and every check-then-modify sequence against the underlying <see cref="IMemoryCache"/> is serialized through a single private lock,
/// so a removal can never be undone by a creation method that completes at the same time.</para>
/// </remarks>
public sealed class AsyncCache
{
    // Serializes removals and every check-then-modify sequence on _cache.
    private readonly object _mutex = new();
    private readonly IMemoryCache _cache;
    private readonly ILogger<AsyncCache> _logger;

    /// <summary>
    /// Creates an asynchronous cache wrapping an existing memory cache.
    /// </summary>
    /// <param name="cache">The memory cache that stores the entries.</param>
    /// <param name="logger">The logger used for trace-level diagnostics.</param>
    public AsyncCache(IMemoryCache cache, ILogger<AsyncCache> logger)
    {
        _cache = cache;
        _logger = logger;
    }

    /// <summary>
    /// Registers <see cref="AsyncCache"/> as a singleton, together with the memory cache and logging services it depends on.
    /// Does nothing if a registration for <see cref="AsyncCache"/> already exists.
    /// </summary>
    /// <param name="services">The service collection to register into.</param>
    internal static void Register(IServiceCollection services)
    {
        if (services.Any(x => x.ServiceType == typeof(AsyncCache)))
            return;
        services.AddMemoryCache();
        services.AddLogging();
        services.AddSingleton<AsyncCache>();
    }

    /// <summary>
    /// Removes an item from the cache, whichever future it currently holds.
    /// </summary>
    /// <param name="key">The key of the item.</param>
    public void Remove(object key)
    {
        using var _ = _logger.BeginDataScope(new {cacheKey = key});
        // Take the lock so this removal cannot interleave with the check-and-replace in InvokeAndPropagateCompletion.
        // Without it, a creation method completing concurrently could re-insert the entry immediately after it was removed.
        lock (_mutex)
        {
            _logger.LogTrace("Removing entry.");
            _cache.Remove(key);
        }
    }

    /// <summary>
    /// Removes a specific future from the cache.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="value">The future that has to match the entry.</param>
    /// <returns><c>true</c> if the entry for <paramref name="key"/> held exactly <paramref name="value"/> and was removed;
    /// <c>false</c> if there was no matching entry or it held a different future (in which case the cache is left unchanged).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="value"/> is <c>null</c>.</exception>
    public bool TryRemove<T>(object key, Task<T> value)
    {
        _ = value ?? throw new ArgumentNullException(nameof(value));
        using var __ = _logger.BeginDataScope(new {cacheKey = key, taskId = value.Id});
        lock (_mutex)
        {
            // Reference comparison: only remove the entry if it still holds this exact future.
            var existingTask = _cache.TryGetValue(key, out TaskCompletionSource<T> tcs) ? tcs.Task : null;
            if (existingTask != value)
            {
                if (existingTask == null)
                    _logger.LogTrace("Attempted to remove entry, but it was already removed.");
                else
                    _logger.LogTrace("Attempted to remove entry, but it had already been replaced by {existingTaskId}.", existingTask.Id);
                return false;
            }
            _logger.LogTrace("Removing entry.");
            _cache.Remove(key);
            return true;
        }
    }

    /// <summary>
    /// Attempts to retrieve an item from the cache.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="task">On return, contains a future item, or <c>null</c> if no entry of type <typeparamref name="T"/> was found.</param>
    /// <returns><c>true</c> if a future was found; otherwise <c>false</c>.</returns>
    public bool TryGet<T>(object key, out Task<T>? task)
    {
        using var _ = _logger.BeginDataScope(new {cacheKey = key});
        // A single read that does not modify the cache, so the lock is not taken here.
        task = _cache.TryGetValue(key, out TaskCompletionSource<T> tcs) ? tcs.Task : null;
        if (task == null)
            _logger.LogTrace("Attempted to retrieve entry, but it was not found.");
        else
            _logger.LogTrace("Retrieved entry {taskId}.", task.Id);
        return task != null;
    }

    /// <summary>
    /// Atomically retrieves or creates a cache item.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="create">An asynchronous creation method. This method will only be invoked once. The creation method may control the cache entry behavior for the resulting value by using its <see cref="ICacheEntry"/> parameter; the <see cref="ICacheEntry.Value"/> member is ignored, but all other members are honored.</param>
    /// <returns>A future item. If the creation method fails or is cancelled, this future is faulted or cancelled accordingly and the entry is removed from the cache.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="create"/> is <c>null</c>.</exception>
    public Task<T> GetOrCreateAsync<T>(object key, Func<ICacheEntry, Task<T>> create)
    {
        _ = create ?? throw new ArgumentNullException(nameof(create));
        using var __ = _logger.BeginDataScope(new {cacheKey = key});
        TaskCompletionSource<T> tcs;
        CancellationTokenSource cts;
        lock (_mutex)
        {
            if (_cache.TryGetValue(key, out tcs))
            {
                _logger.LogTrace("GetOrCreate found existing entry {taskId}.", tcs.Task.Id);
                return tcs.Task;
            }

            // Run continuations asynchronously so awaiting callers' code does not run inline on the thread that completes the creation method.
            tcs = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);

            // Placeholder entry holding the future while creation runs. It gets no expiration settings here; it leaves the cache when
            // cts is cancelled (creation failed), when it is replaced by the entry configured by the creation method (creation succeeded),
            // or by ordinary removal/eviction. Its post-eviction callback disposes cts in every case.
            using var entry = SafeCacheEntry.Create(_cache, key).SetSize(1);
            cts = new CancellationTokenSource();
            entry.Value = tcs;
            entry.RegisterPostEvictionCallback((_, _, _, _) => cts.Dispose());
            entry.AddExpirationToken(new CancellationChangeToken(cts.Token));
            _logger.LogTrace("GetOrCreate creating new entry {taskId}.", tcs.Task.Id);
            // 'entry' is disposed (and so committed to the cache) at the end of this block, while the lock is still held.
        }

        // The second entry is handed to the creation method to configure; it is committed only if creation succeeds.
        InvokeAndPropagateCompletion(create, SafeCacheEntry.Create(_cache, key).SetSize(1), tcs, cts);
        return tcs.Task;
    }

    /// <summary>
    /// Invokes the creation method and (asynchronously) updates the cache entry with the results.
    /// - If the function succeeds synchronously, the cache entry is updated and the TCS completed by the time this method returns.
    /// - If the function fails synchronously, the cache entry is removed and the TCS faulted by the time this method returns.
    /// - If the function succeeds asynchronously, the cache entry is updated when the function completes *if* the cache entry has not changed by that time.
    /// - If the function faults asynchronously, the cache entry is removed when the function completes *if* the cache entry has not changed by that time.
    /// </summary>
    /// <remarks>
    /// This is deliberately fire-and-forget (<c>async void</c>): the outcome is reported through <paramref name="tcs"/>, and every exception
    /// thrown by <paramref name="create"/> or by the cache update inside the <c>try</c> block is caught below.
    /// </remarks>
    /// <typeparam name="T">The type of object created by the <paramref name="create"/> method.</typeparam>
    /// <param name="create">The creation method, which may update the cache entry set when the creation method completes. The <see cref="ICacheEntry.Value"/> member is ignored, but all other members are honored.</param>
    /// <param name="cacheEntry">The cache entry that will be used to replace the cache entry currently containing <paramref name="tcs"/> if the creation succeeds.</param>
    /// <param name="tcs">The task completion source currently in the cache entry. This method will (eventually) complete this task completion source.</param>
    /// <param name="cts">The cancellation token source used to evict the current cache entry if the creation method fails.</param>
    private async void InvokeAndPropagateCompletion<T>(Func<ICacheEntry, Task<T>> create, ICacheEntry cacheEntry, TaskCompletionSource<T> tcs, CancellationTokenSource cts)
    {
        try
        {
            // Asynchronously create the value.
            var result = await create(cacheEntry);

            // Atomically:
            // - Check to see if we're still the one in the cache, and
            // - If we are, update the cache entry with a new one having the same TCS value, but including new expiration and other settings from the creation method.
            lock (_mutex)
            {
                // This check is necessary to avoid a race condition where our entry has been removed and re-created.
                // In that case, there will be a cache entry but it will not be our cache entry, so we should not replace it.
                // Rather, we'll leave the cache as-is (without our entry) and just complete our listeners (via TrySetResult, below).
                if (_cache.TryGetValue(cacheEntry.Key, out TaskCompletionSource<T> existingTcs) && existingTcs == tcs)
                {
                    _logger.LogTrace("GetOrCreate creation completed successfully; updating entry {taskId}.", tcs.Task.Id);
                    // Disposing the entry commits it, replacing the placeholder (whose eviction callback then disposes cts).
                    using (cacheEntry)
                        cacheEntry.Value = tcs;
                }
                else
                {
                    // cacheEntry is intentionally NOT disposed here: disposing would commit it and overwrite the replacement or resurrect the removed entry.
                    if (existingTcs == null)
                        _logger.LogTrace("GetOrCreate creation completed successfully, but entry {taskId} has been removed.", tcs.Task.Id);
                    else
                        _logger.LogTrace("GetOrCreate creation completed successfully, but entry {taskId} has been replaced by entry {replacementTaskId}.", tcs.Task.Id, existingTcs.Task.Id);
                }
            }

            // Propagate the result to any listeners.
            tcs.TrySetResult(result);
        }
        catch (OperationCanceledException oce)
        {
            _logger.LogTrace("GetOrCreate creation cancelled; removing entry {taskId}.", tcs.Task.Id);
            // Remove the cache entry. This will throw ObjectDisposedException if the cache entry has already been removed and AggregateException if any cancellation callbacks throw.
            // Both cases are expected and safe to ignore: the only goal is to make sure the placeholder is gone.
            try { cts.Cancel(); } catch (Exception) { }
            // Propagate the cancellation to any listeners, preserving the original token when it is available.
            if (oce.CancellationToken.IsCancellationRequested)
                tcs.TrySetCanceled(oce.CancellationToken);
            else
                tcs.TrySetCanceled();
        }
        catch (Exception ex)
        {
            _logger.LogTrace("GetOrCreate creation failed; removing entry {taskId}.", tcs.Task.Id);
            // Remove the cache entry. This will throw ObjectDisposedException if the cache entry has already been removed and AggregateException if any cancellation callbacks throw.
            // Both cases are expected and safe to ignore: the only goal is to make sure the placeholder is gone.
            try { cts.Cancel(); } catch (Exception) { }
            // Propagate the exception to any listeners.
            tcs.TrySetException(ex);
        }
    }

    /// <summary>
    /// Wraps an <see cref="ICacheEntry"/> so that creating it and disposing it each run inside a separate <c>async</c> method.
    /// An <c>async</c> method restores the caller's <see cref="ExecutionContext"/> when it returns, so any <see cref="AsyncLocal{T}"/> state
    /// changed by <see cref="IMemoryCache.CreateEntry"/> or <see cref="IDisposable.Dispose"/> does not leak into the calling context.
    /// NOTE(review): presumably this targets MemoryCache's ambient entry-linking scope; confirm against the package version in use.
    /// All other members forward directly to the wrapped entry.
    /// </summary>
    private sealed class SafeCacheEntry : ICacheEntry
    {
        private readonly ICacheEntry _cacheEntryImplementation;

        private SafeCacheEntry(ICacheEntry cacheEntryImplementation) => _cacheEntryImplementation = cacheEntryImplementation;

        /// <summary>
        /// Creates a new entry for <paramref name="key"/> in <paramref name="cache"/>, wrapped in a <see cref="SafeCacheEntry"/>.
        /// </summary>
        public static ICacheEntry Create(IMemoryCache cache, object key)
        {
            // The local async method contains no awaits, so it completes synchronously and GetResult does not block.
            return AsyncCreateEntry().GetAwaiter().GetResult();
#pragma warning disable 1998
            async Task<ICacheEntry> AsyncCreateEntry() => new SafeCacheEntry(cache.CreateEntry(key));
#pragma warning restore 1998
        }

        /// <summary>
        /// Disposes (and thereby commits) the wrapped entry.
        /// </summary>
        public void Dispose()
        {
            // The local async method contains no awaits, so it completes synchronously and GetResult does not block.
            AsyncDispose().GetAwaiter().GetResult();
#pragma warning disable 1998
            async Task AsyncDispose() => _cacheEntryImplementation.Dispose();
#pragma warning restore 1998
        }

        public object Key => _cacheEntryImplementation.Key;

        public object Value
        {
            get => _cacheEntryImplementation.Value;
            set => _cacheEntryImplementation.Value = value;
        }

        public DateTimeOffset? AbsoluteExpiration
        {
            get => _cacheEntryImplementation.AbsoluteExpiration;
            set => _cacheEntryImplementation.AbsoluteExpiration = value;
        }

        public TimeSpan? AbsoluteExpirationRelativeToNow
        {
            get => _cacheEntryImplementation.AbsoluteExpirationRelativeToNow;
            set => _cacheEntryImplementation.AbsoluteExpirationRelativeToNow = value;
        }

        public TimeSpan? SlidingExpiration
        {
            get => _cacheEntryImplementation.SlidingExpiration;
            set => _cacheEntryImplementation.SlidingExpiration = value;
        }

        public IList<IChangeToken> ExpirationTokens => _cacheEntryImplementation.ExpirationTokens;

        public IList<PostEvictionCallbackRegistration> PostEvictionCallbacks => _cacheEntryImplementation.PostEvictionCallbacks;

        public CacheItemPriority Priority
        {
            get => _cacheEntryImplementation.Priority;
            set => _cacheEntryImplementation.Priority = value;
        }

        public long? Size
        {
            get => _cacheEntryImplementation.Size;
            set => _cacheEntryImplementation.Size = value;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment