Created September 11, 2020 08:48
-
-
Save afscrome/f8b5a8c641f74d15da4eb6860a077b16 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Asynchronous per-key cache whose entries can be explicitly refreshed while the previously
/// loaded value keeps being served.
/// </summary>
/// <remarks>
/// Entries are stored as <see cref="Task{TResult}"/> so callers asking for a key whose load is
/// still in flight all share that single load. Loads that fault or are cancelled are evicted so
/// the next <see cref="GetAsync"/> call retries. The cache is an immutable dictionary that is only
/// ever replaced atomically through <see cref="ImmutableInterlocked"/>; no locks are taken.
/// <para>
/// TODO: Automate refreshing. Rather than refreshing each entry individually, run a task every
/// 5 mins (presumably <see cref="RefreshInterval"/>) that refreshes every entry one by one before
/// sleeping until the next cycle.
/// </para>
/// </remarks>
/// <typeparam name="TKey">Cache key type. Keys must be non-null (immutable dictionaries reject null keys).</typeparam>
/// <typeparam name="TValue">Cached value type.</typeparam>
public class BackgroundRefreshableCache<TKey, TValue>
{
    /// <summary>
    /// Intended interval between automatic background refreshes.
    /// </summary>
    /// <remarks>
    /// Not used yet: nothing assigns this property, so it is always <see cref="TimeSpan.Zero"/>,
    /// and automatic refreshing has not been implemented.
    /// </remarks>
    public TimeSpan RefreshInterval { get; }

    // Produces the value for a key; used for both initial loads and refreshes.
    private readonly Func<TKey, Task<TValue>> _loader;

    // Current cache snapshot. Only ever swapped atomically via ImmutableInterlocked.
    private ImmutableDictionary<TKey, Task<TValue>> _cache = ImmutableDictionary<TKey, Task<TValue>>.Empty;

    /// <summary>Creates a cache that populates entries with <paramref name="loader"/>.</summary>
    /// <param name="loader">Loads the value for a key. Must not return a null task.</param>
    /// <exception cref="ArgumentNullException"><paramref name="loader"/> is null.</exception>
    public BackgroundRefreshableCache(Func<TKey, Task<TValue>> loader)
    {
        _loader = loader ?? throw new ArgumentNullException(nameof(loader));
    }

    /// <summary>
    /// Returns the cached value task for <paramref name="key"/>, starting a load if nothing is cached.
    /// </summary>
    /// <remarks>
    /// Callers arriving while a load is in flight receive the same task. If that load faults or is
    /// cancelled, it is evicted so a later call starts a fresh load; the failed task itself is still
    /// returned to the callers that already obtained it. If several callers race on a key's first
    /// access, the loader may be invoked more than once, but only one resulting task is cached and
    /// returned to all of them.
    /// </remarks>
    /// <param name="key">Key to look up.</param>
    /// <returns>A task producing the value for <paramref name="key"/>.</returns>
    public Task<TValue> GetAsync(TKey key)
    {
        var itemTask = ImmutableInterlocked.GetOrAdd(ref _cache, key, LoadAndEvictOnFailure);

        // The cached load may already have failed: the loader can hand back an already-faulted or
        // cancelled task (which gets no eviction continuation), or this call may have raced ahead of
        // the continuation. Clear it now so the next call retries.
        if (itemTask.IsCompleted && itemTask.Status != TaskStatus.RanToCompletion)
        {
            RemoveIfCurrent(key, itemTask);
        }

        return itemTask;
    }

    /// <summary>
    /// Reloads <paramref name="key"/> and, only if the load succeeds, replaces the cached entry.
    /// </summary>
    /// <remarks>
    /// Until the reload completes, <see cref="GetAsync"/> keeps returning the previously cached entry.
    /// If the reload faults or is cancelled, the exception propagates to the caller and the cache is
    /// left untouched.
    /// </remarks>
    /// <param name="key">Key to reload.</param>
    /// <returns>The freshly loaded value.</returns>
    public async Task<TValue> Refresh(TKey key)
    {
        var task = _loader(key);
        var value = await task.ConfigureAwait(false);

        // Only publish once the value has loaded successfully, so readers never observe a pending
        // or failed refresh. The (now completed) task is stored because the cache holds tasks.
        ImmutableInterlocked.AddOrUpdate(ref _cache, key, task, (existingKey, existingTask) => task);

        return value;
    }

    // Starts a load for key and, if it is still running, arranges for it to be evicted should it
    // fault or be cancelled. Invoked by GetOrAdd only when the key is missing.
    private Task<TValue> LoadAndEvictOnFailure(TKey key)
    {
        var loadTask = _loader(key)
            ?? throw new InvalidOperationException("The loader returned a null task.");

        if (!loadTask.IsCompleted)
        {
            // If this task loses the GetOrAdd race it is never cached, and the conditional removal
            // below is simply a no-op.
            loadTask.ContinueWith(
                _ => RemoveIfCurrent(key, loadTask),
                CancellationToken.None,
                TaskContinuationOptions.NotOnRanToCompletion | TaskContinuationOptions.ExecuteSynchronously,
                TaskScheduler.Default);
        }

        return loadTask;
    }

    // Evicts key only if it still maps to expected, so a stale failure can never remove a newer
    // entry (a retry started by GetAsync or a value published by Refresh).
    private void RemoveIfCurrent(TKey key, Task<TValue> expected)
    {
        ImmutableInterlocked.Update(
            ref _cache,
            cache => cache.TryGetValue(key, out var current) && ReferenceEquals(current, expected)
                ? cache.Remove(key)
                : cache);
    }
}
/// <summary>
/// NUnit tests for <see cref="BackgroundRefreshableCache{TKey, TValue}"/>.
/// Each test drives the cache with a loader that hands out a scripted sequence of tasks (built by
/// <see cref="RefreshableLoadingSequence{TValue}"/>), so the number and outcome of loads is controlled.
/// </summary>
public class Tests
{
    [Test]
    public async Task When_Making_First_Call_To_Value_Then_Blocked_Until_Value_Is_Loaded()
    {
        var tcs = new TaskCompletionSource<int>();
        var refreshable = RefreshableLoadingSequence(
            tcs.Task
        );

        var initialValueTask = refreshable.GetAsync("key");

        // Ensure the value has not been loaded prematurely.
        await Task.Delay(TimeSpan.FromMilliseconds(1));
        Assert.False(initialValueTask.IsCompleted);

        // Complete the load.
        tcs.SetResult(45);

        // Verify the value is returned (a second load would exhaust the scripted sequence and throw).
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo(45));
    }

    [Test]
    public async Task When_Making_Multiple_Calls_To_Value_Before_Initial_Load_Then_Value_Only_Loaded_Once()
    {
        var tcs = new TaskCompletionSource<string>();
        var refreshable = RefreshableLoadingSequence(
            tcs.Task,
            Task.FromResult("Now you don't")
        );

        var consumer1 = refreshable.GetAsync("key");
        var consumer2 = refreshable.GetAsync("key");
        var consumer3 = refreshable.GetAsync("key");

        // Ensure values have not been loaded prematurely.
        await Task.Delay(TimeSpan.FromMilliseconds(1));
        Assert.False(consumer1.IsCompleted);
        Assert.False(consumer2.IsCompleted);
        Assert.False(consumer3.IsCompleted);

        // Complete the load.
        tcs.SetResult("Now you see me");

        // Verify all consumers got the value from the first (and only) load.
        Assert.That(await consumer1, Is.EqualTo("Now you see me"));
        Assert.That(await consumer2, Is.EqualTo("Now you see me"));
        Assert.That(await consumer3, Is.EqualTo("Now you see me"));
    }

    [Test]
    public async Task When_Getting_Value_Then_Value_Is_Cached()
    {
        var refreshable = RefreshableLoadingSequence(
            Task.FromResult("Alice"),
            Task.FromResult("Bob")
        );

        // "Bob" would only appear if the cache reloaded.
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo("Alice"));
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo("Alice"));
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo("Alice"));
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo("Alice"));
    }

    [Test]
    public async Task When_Refresh_Succeeds_Then_Value_Is_Updated()
    {
        var refreshable = RefreshableLoadingSequence(
            Task.FromResult("first"),
            Task.FromResult("second")
        );

        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo("first"));

        var refresh = await refreshable.Refresh("key");
        Assert.That(refresh, Is.EqualTo("second"));

        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo("second"));
    }

    [Test]
    public async Task When_Refreshing_Then_Previous_Value_Returned_Until_Refresh_Completes()
    {
        var tcs = new TaskCompletionSource<double>();
        var refreshable = RefreshableLoadingSequence(
            Task.FromResult(3.142),
            tcs.Task
        );

        // Initial load.
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo(3.142));

        // Start the refresh (its load stays pending until tcs is completed).
        var refreshTask = refreshable.Refresh("key");

        // Ensure the old value is still being returned.
        await Task.Delay(TimeSpan.FromMilliseconds(1));
        var initialValueTask = refreshable.GetAsync("key");
        Assert.True(initialValueTask.IsCompleted);
        Assert.That(initialValueTask.Result, Is.EqualTo(3.142));

        // Complete the refresh.
        tcs.SetResult(2.718);
        Assert.That(await refreshTask, Is.EqualTo(2.718));

        // The value now returns the updated result.
        var postRefreshValueTask = refreshable.GetAsync("key");
        Assert.True(postRefreshValueTask.IsCompleted);
        Assert.That(postRefreshValueTask.Result, Is.EqualTo(2.718));
    }

    [Test]
    public async Task When_Refresh_Throws_Exception_Then_Old_Value_Persists()
    {
        var exception = new Exception("Simulated failure");
        var refreshable = RefreshableLoadingSequence(
            Task.FromResult('%'),
            Task.FromException<char>(exception)
        );

        // Initial load.
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo('%'));

        // Ensure Refresh throws the expected exception.
        var caughtException = Assert.ThrowsAsync<Exception>(() => refreshable.Refresh("key"));
        Assert.That(caughtException, Is.EqualTo(exception));

        // But the value still returns the previous result.
        var postRefreshValueTask = refreshable.GetAsync("key");
        Assert.True(postRefreshValueTask.IsCompleted);
        Assert.That(postRefreshValueTask.Result, Is.EqualTo('%'));
    }

    [Test]
    public async Task When_Refresh_Is_Cancelled_Then_Old_Value_Persists()
    {
        // An already-cancelled token is all Task.FromCanceled needs; no CancellationTokenSource to dispose.
        var refreshable = RefreshableLoadingSequence(
            Task.FromResult('%'),
            Task.FromCanceled<char>(new CancellationToken(canceled: true))
        );

        // Initial load.
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo('%'));

        // Ensure Refresh is cancelled.
        Assert.ThrowsAsync<TaskCanceledException>(() => refreshable.Refresh("key"));

        // But the value still returns the previous result.
        var postRefreshValueTask = refreshable.GetAsync("key");
        Assert.True(postRefreshValueTask.IsCompleted);
        Assert.That(postRefreshValueTask.Result, Is.EqualTo('%'));
    }

    [Test]
    public async Task When_Initial_Load_Throws_Exception_Then_Next_Call_To_Value_Retries()
    {
        var initialException = new Exception("Initial Exception");
        var refreshable = RefreshableLoadingSequence(
            Task.FromException<string>(initialException),
            Task.FromResult("I'm alive")
        );

        // First load fails.
        var caughtInitialValueException = Assert.ThrowsAsync<Exception>(() => refreshable.GetAsync("key"));
        Assert.That(caughtInitialValueException, Is.EqualTo(initialException));

        // Second load retries and succeeds.
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo("I'm alive"));
    }

    [Test]
    public async Task When_Initial_Load_Is_Cancelled_Then_Next_Call_To_Value_Retries()
    {
        var refreshable = RefreshableLoadingSequence(
            Task.FromCanceled<decimal>(new CancellationToken(canceled: true)),
            Task.FromResult(93.42m)
        );

        // First load is cancelled.
        Assert.ThrowsAsync<TaskCanceledException>(() => refreshable.GetAsync("key"));

        // Second load retries and succeeds.
        Assert.That(await refreshable.GetAsync("key"), Is.EqualTo(93.42m));
    }

    /// <summary>
    /// Builds a cache whose loader returns the given tasks in order, one per load, regardless of key.
    /// </summary>
    /// <param name="sequence">Tasks to hand out, in load order.</param>
    /// <returns>A cache backed by the scripted loader.</returns>
    /// <remarks>
    /// The loader throws <see cref="InvalidOperationException"/> if it is invoked more times than
    /// there are scripted tasks, i.e. when the cache loaded more often than the test expected.
    /// </remarks>
    private BackgroundRefreshableCache<object, TValue> RefreshableLoadingSequence<TValue>(
        params Task<TValue>[] sequence)
    {
        int index = 0;
        return new BackgroundRefreshableCache<object, TValue>(_ =>
        {
            if (index >= sequence.Length)
            {
                // IndexOutOfRangeException is reserved for the runtime (CA2201); report the test failure clearly instead.
                throw new InvalidOperationException(
                    $"Loader was called {index + 1} times but only {sequence.Length} load(s) were scripted.");
            }

            return sequence[index++];
        });
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.