-
-
Save StephenCleary/39a2cd0aa3c705a984a4dbbea8275fe9 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
using Nito.Logging;
/// <summary> | |
/// Provides an asynchronous cache with exactly-once creation method semantics and flexible cache entries. | |
/// </summary> | |
/// <summary>
/// Provides an asynchronous cache with exactly-once creation method semantics and flexible cache entries.
/// </summary>
/// <remarks>
/// Every value this class stores is a <see cref="TaskCompletionSource{TResult}"/>; callers always receive its task.
/// While a creation method is running, its entry carries an expiration token backed by a private
/// <see cref="CancellationTokenSource"/>; cancelling that source evicts the entry when creation fails or is cancelled.
/// </remarks>
public sealed class AsyncCache
{
    // Serializes the lookup-then-insert/replace/remove sequences performed by this class.
    // Other code using the same IMemoryCache directly does not take this lock.
    private readonly object _mutex = new();
    private readonly IMemoryCache _cache;
    private readonly ILogger<AsyncCache> _logger;

    /// <summary>
    /// Creates an asynchronous cache wrapping an existing memory cache.
    /// </summary>
    /// <param name="cache">The underlying memory cache that stores the entries.</param>
    /// <param name="logger">The logger used for diagnostics.</param>
    public AsyncCache(IMemoryCache cache, ILogger<AsyncCache> logger)
    {
        _cache = cache;
        _logger = logger;
    }

    /// <summary>
    /// Registers <see cref="AsyncCache"/> as a singleton, together with the memory cache and logging services it depends on.
    /// Does nothing if <see cref="AsyncCache"/> has already been registered.
    /// </summary>
    /// <param name="services">The service collection to register into.</param>
    internal static void Register(IServiceCollection services)
    {
        if (services.Any(x => x.ServiceType == typeof(AsyncCache)))
            return;
        services.AddMemoryCache();
        services.AddLogging();
        services.AddSingleton<AsyncCache>();
    }

    /// <summary>
    /// Removes an item from the cache, whatever its current value is.
    /// If a creation method for this key is still running, its result will not be re-added to the cache when it completes.
    /// </summary>
    /// <param name="key">The key of the item.</param>
    public void Remove(object key)
    {
        using var _ = _logger.BeginDataScope(new {cacheKey = key});
        _logger.LogTrace("Removing entry.");
        _cache.Remove(key);
    }

    /// <summary>
    /// Removes a specific future from the cache. The entry is only removed if it still holds <paramref name="value"/>;
    /// an entry that has since been removed or replaced is left alone.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="value">The future that has to match the entry.</param>
    /// <returns><c>true</c> if the entry was removed; <c>false</c> if it was missing or held a different future.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="value"/> is <c>null</c>.</exception>
    public bool TryRemove<T>(object key, Task<T> value)
    {
        _ = value ?? throw new ArgumentNullException(nameof(value));
        using var __ = _logger.BeginDataScope(new {cacheKey = key, taskId = value.Id});
        lock (_mutex)
        {
            // A null value stored under the key is treated the same as a missing entry.
            var existingTask = _cache.TryGetValue(key, out TaskCompletionSource<T>? tcs) ? tcs?.Task : null;
            if (existingTask != value)
            {
                if (existingTask == null)
                    _logger.LogTrace("Attempted to remove entry, but it was already removed.");
                else
                    _logger.LogTrace("Attempted to remove entry, but it had already been replaced by {existingTaskId}.", existingTask.Id);
                return false;
            }
            _logger.LogTrace("Removing entry.");
            _cache.Remove(key);
            return true;
        }
    }

    /// <summary>
    /// Attempts to retrieve an item from the cache.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="task">On return, contains a future item, or <c>null</c> if no entry was found.</param>
    /// <returns><c>true</c> if an entry was found; otherwise <c>false</c>.</returns>
    public bool TryGet<T>(object key, [NotNullWhen(true)] out Task<T>? task)
    {
        using var _ = _logger.BeginDataScope(new {cacheKey = key});
        // A null value stored under the key is treated the same as a missing entry.
        task = _cache.TryGetValue(key, out TaskCompletionSource<T>? tcs) ? tcs?.Task : null;
        if (task == null)
            _logger.LogTrace("Attempted to retrieve entry, but it was not found.");
        else
            _logger.LogTrace("Retrieved entry {taskId}.", task.Id);
        return task != null;
    }

    /// <summary>
    /// Atomically retrieves or creates a cache item.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="create">An asynchronous creation method. This method will only be invoked once. The creation method may control the cache entry behavior for the resulting value by using its <see cref="ICacheEntry"/> parameter; the <see cref="ICacheEntry.Value"/> member is ignored, but all other members are honored.</param>
    /// <returns>A future item.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="create"/> is <c>null</c>.</exception>
    public Task<T> GetOrCreateAsync<T>(object key, Func<ICacheEntry, Task<T>> create)
    {
        // Validate eagerly. Otherwise a null delegate would only fail inside InvokeAndPropagateCompletion,
        // after a placeholder entry had already been inserted into the cache.
        if (create is null)
            throw new ArgumentNullException(nameof(create));
        using var _ = _logger.BeginDataScope(new {cacheKey = key});
        TaskCompletionSource<T> tcs;
        CancellationTokenSource cts;
        lock (_mutex)
        {
            if (_cache.TryGetValue(key, out TaskCompletionSource<T>? existingTcs) && existingTcs is not null)
            {
                _logger.LogTrace("GetOrCreate found existing entry {taskId}.", existingTcs.Task.Id);
                return existingTcs.Task;
            }

            // RunContinuationsAsynchronously keeps awaiters' continuations from running inline on the thread that
            // completes this TCS (i.e., inside InvokeAndPropagateCompletion).
            tcs = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);

            // Insert a placeholder entry holding the TCS; it is committed to the cache when `entry` is disposed at the end of this scope.
            using var entry = SafeCacheEntry.Create(_cache, key).SetSize(1);
            cts = new CancellationTokenSource();
            entry.Value = tcs;
            // The CTS lives exactly as long as the placeholder entry.
            entry.RegisterPostEvictionCallback((_, _, _, _) => cts.Dispose());
            // Cancelling the CTS expires the placeholder entry; used to evict it if creation fails.
            entry.AddExpirationToken(new CancellationChangeToken(cts.Token));
            _logger.LogTrace("GetOrCreate creating new entry {taskId}.", tcs.Task.Id);
        }

        // The creation method runs outside the lock. A second entry is handed to it so it can configure expiration
        // settings for the final value; that entry replaces the placeholder once creation succeeds.
        InvokeAndPropagateCompletion(create, SafeCacheEntry.Create(_cache, key).SetSize(1), tcs, cts);
        return tcs.Task;
    }

    /// <summary>
    /// Invokes the creation method and (asynchronously) updates the cache entry with the results.
    /// - If the function succeeds synchronously, the cache entry is updated and the TCS completed by the time this method returns.
    /// - If the function fails synchronously, the cache entry is removed and the TCS faulted by the time this method returns.
    /// - If the function succeeds asynchronously, the cache entry is updated when the function completes *if* the cache entry has not changed by that time.
    /// - If the function faults asynchronously, the cache entry is removed when the function completes *if* the cache entry has not changed by that time.
    /// </summary>
    /// <remarks>
    /// This method is <c>async void</c> on purpose. Callers observe the outcome through <paramref name="tcs"/>, and every
    /// exception thrown by <paramref name="create"/> (synchronously or asynchronously) is caught and routed into it.
    /// </remarks>
    /// <typeparam name="T">The type of object created by the <paramref name="create"/> method.</typeparam>
    /// <param name="create">The creation method, which may update the cache entry set when the creation method completes. The <see cref="ICacheEntry.Value"/> member is ignored, but all other members are honored.</param>
    /// <param name="cacheEntry">The cache entry that will be used to replace the cache entry currently containing <paramref name="tcs"/> if the creation succeeds.</param>
    /// <param name="tcs">The task completion source currently in the cache entry. This method will (eventually) complete this task completion source.</param>
    /// <param name="cts">The cancellation token source used to evict the current cache entry if the creation method fails.</param>
    private async void InvokeAndPropagateCompletion<T>(Func<ICacheEntry, Task<T>> create, ICacheEntry cacheEntry, TaskCompletionSource<T> tcs, CancellationTokenSource cts)
    {
        try
        {
            // Asynchronously create the value.
            var result = await create(cacheEntry);

            // Atomically:
            // - Check to see if we're still the one in the cache, and
            // - If we are, update the cache entry with a new one having the same TCS value, but including new expiration and other settings from the creation method.
            lock (_mutex)
            {
                // This check is necessary to avoid a race condition where our entry has been removed and re-created.
                // In that case, there will be a cache entry but it will not be our cache entry, so we should not replace it.
                // Rather, we'll leave the cache as-is (without our entry) and just complete our listeners (via TrySetResult, below).
                if (_cache.TryGetValue(cacheEntry.Key, out TaskCompletionSource<T>? existingTcs) && existingTcs == tcs)
                {
                    _logger.LogTrace("GetOrCreate creation completed successfully; updating entry {taskId}.", tcs.Task.Id);
                    // Disposing commits the replacement entry; the placeholder is evicted, and its callback disposes `cts`.
                    using (cacheEntry)
                        cacheEntry.Value = tcs;
                }
                else
                {
                    if (existingTcs == null)
                        _logger.LogTrace("GetOrCreate creation completed successfully, but entry {taskId} has been removed.", tcs.Task.Id);
                    else
                        _logger.LogTrace("GetOrCreate creation completed successfully, but entry {taskId} has been replaced by entry {replacementTaskId}.", tcs.Task.Id, existingTcs.Task.Id);
                }
            }

            // Propagate the result to any listeners.
            tcs.TrySetResult(result);
        }
        catch (OperationCanceledException oce)
        {
            _logger.LogTrace("GetOrCreate creation cancelled; removing entry {taskId}.", tcs.Task.Id);
            EvictEntry(cts, tcs.Task.Id);
            // Propagate the cancellation to any listeners.
            if (oce.CancellationToken.IsCancellationRequested)
                tcs.TrySetCanceled(oce.CancellationToken);
            else
                tcs.TrySetCanceled();
        }
        catch (Exception ex)
        {
            _logger.LogTrace("GetOrCreate creation failed; removing entry {taskId}.", tcs.Task.Id);
            EvictEntry(cts, tcs.Task.Id);
            // Propagate the exception to any listeners.
            tcs.TrySetException(ex);
        }
    }

    /// <summary>
    /// Evicts the placeholder entry tied to <paramref name="cts"/> by cancelling it. Never throws, so the caller can
    /// always go on to complete its task completion source.
    /// </summary>
    /// <param name="cts">The cancellation token source backing the entry's expiration token.</param>
    /// <param name="taskId">The id of the entry's task, used only for logging.</param>
    private void EvictEntry(CancellationTokenSource cts, int taskId)
    {
        try
        {
            cts.Cancel();
        }
        catch (ObjectDisposedException)
        {
            // Expected: the entry was already evicted, and its post-eviction callback disposed the CTS.
            _logger.LogTrace("Entry {taskId} had already been evicted.", taskId);
        }
        catch (Exception ex)
        {
            // Cancel() runs every registered callback before throwing an AggregateException for the ones that failed.
            // Log the failure instead of swallowing it silently; the caller must still complete its TCS.
            _logger.LogWarning(ex, "Cancellation callbacks threw while evicting entry {taskId}.", taskId);
        }
    }

    /// <summary>
    /// An <see cref="ICacheEntry"/> decorator whose creation and disposal run inside async local functions.
    /// </summary>
    /// <remarks>
    /// An async method restores its caller's execution context when it returns. Any <c>AsyncLocal</c> state that
    /// <see cref="IMemoryCache.CreateEntry"/> or <see cref="IDisposable.Dispose"/> set therefore does not leak into the
    /// caller. (Some <c>MemoryCache</c> versions track the "current" entry this way.) The local functions never await,
    /// so <c>GetAwaiter().GetResult()</c> does not block.
    /// </remarks>
    private sealed class SafeCacheEntry : ICacheEntry
    {
        private readonly ICacheEntry _cacheEntryImplementation;

        private SafeCacheEntry(ICacheEntry cacheEntryImplementation) => _cacheEntryImplementation = cacheEntryImplementation;

        /// <summary>
        /// Creates a new entry for <paramref name="key"/> in <paramref name="cache"/>, isolating the caller's execution context.
        /// </summary>
        public static ICacheEntry Create(IMemoryCache cache, object key)
        {
            return AsyncCreateEntry().GetAwaiter().GetResult();
#pragma warning disable 1998
            async Task<ICacheEntry> AsyncCreateEntry() => new SafeCacheEntry(cache.CreateEntry(key));
#pragma warning restore 1998
        }

        /// <summary>
        /// Disposes (and thereby commits) the wrapped entry, isolating the caller's execution context.
        /// </summary>
        public void Dispose()
        {
            AsyncDispose().GetAwaiter().GetResult();
#pragma warning disable 1998
            async Task AsyncDispose() => _cacheEntryImplementation.Dispose();
#pragma warning restore 1998
        }

        // The remaining members forward directly to the wrapped entry.
        public object Key => _cacheEntryImplementation.Key;

        public object Value
        {
            get => _cacheEntryImplementation.Value;
            set => _cacheEntryImplementation.Value = value;
        }

        public DateTimeOffset? AbsoluteExpiration
        {
            get => _cacheEntryImplementation.AbsoluteExpiration;
            set => _cacheEntryImplementation.AbsoluteExpiration = value;
        }

        public TimeSpan? AbsoluteExpirationRelativeToNow
        {
            get => _cacheEntryImplementation.AbsoluteExpirationRelativeToNow;
            set => _cacheEntryImplementation.AbsoluteExpirationRelativeToNow = value;
        }

        public TimeSpan? SlidingExpiration
        {
            get => _cacheEntryImplementation.SlidingExpiration;
            set => _cacheEntryImplementation.SlidingExpiration = value;
        }

        public IList<IChangeToken> ExpirationTokens => _cacheEntryImplementation.ExpirationTokens;

        public IList<PostEvictionCallbackRegistration> PostEvictionCallbacks => _cacheEntryImplementation.PostEvictionCallbacks;

        public CacheItemPriority Priority
        {
            get => _cacheEntryImplementation.Priority;
            set => _cacheEntryImplementation.Priority = value;
        }

        public long? Size
        {
            get => _cacheEntryImplementation.Size;
            set => _cacheEntryImplementation.Size = value;
        }
    }
}
It would be ideal if the async methods could return `ValueTask<T>` instead of `Task<T>`, but I guess that's not possible as-is because of the use of `TaskCompletionSource`?
@cocowalla I don't see the need for `ValueTask<T>` here. `GetOrCreateAsync` returns a `Task<T>`, but it's a cached instance; it's not creating a new one on each call.
It may be possible to have the `create` delegate return a `ValueTask<T>`, but I'm not sure how much that would improve performance-wise. It would only be called once per cached value.
try { cts.Cancel(); } catch (ObjectDisposedException) { }
This line makes me nervous. `Cancel` might throw an `AggregateException` containing all the exceptions thrown by the callbacks registered on the associated `CancellationToken`. If that happens, the `tcs` will never be completed, so anyone awaiting it will most likely wait forever (effectively a deadlock).
@theodorzoulias Thanks! I believe you're correct, and that this can actually happen with a custom memory cache underlying this one. I'll change it to catch `Exception` instead, since we always want to continue.
Thanks! I do have a bunch of unit tests (just not included in this gist); I'll add this one, too.