Last active
July 5, 2024 07:43
-
-
Save StephenCleary/2d63729144bc3dbc1746e90d42a61861 to your computer and use it in GitHub Desktop.
Helper methods for working with interlocked state
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Threading; | |
/// <summary>
/// Interlocked helper methods for lock-free state transitions.
/// </summary>
public static class InterlockedState
{
    /// <summary>
    /// Executes a state transition from one state to another.
    /// </summary>
    /// <typeparam name="T">The type of the state; this is generally an immutable type. Strongly consider using a record class.</typeparam>
    /// <param name="state">The location of the state.</param>
    /// <param name="transformation">The transformation to apply to the state. This may be invoked any number of times and should be a pure function.</param>
    /// <returns>The old state and the new state.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="transformation"/> is <c>null</c>.</exception>
    public static (T OldState, T NewState) Transform<T>(ref T state, Func<T, T> transformation)
        where T : class? =>
        DoTransform<T, RefImpl<T>>(ref state, transformation);

    /// <summary>
    /// Executes a state transition from one state to another.
    /// </summary>
    /// <param name="state">The location of the state.</param>
    /// <param name="transformation">The transformation to apply to the state. This may be invoked any number of times and should be a pure function.</param>
    /// <returns>The old state and the new state.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="transformation"/> is <c>null</c>.</exception>
    public static (int OldState, int NewState) Transform(ref int state, Func<int, int> transformation) =>
        DoTransform<int, Int32Impl>(ref state, transformation);

    /// <inheritdoc cref="Transform(ref int, Func{int, int})"/>
    public static (uint OldState, uint NewState) Transform(ref uint state, Func<uint, uint> transformation) =>
        DoTransform<uint, UInt32Impl>(ref state, transformation);

    /// <inheritdoc cref="Transform(ref int, Func{int, int})"/>
    public static (long OldState, long NewState) Transform(ref long state, Func<long, long> transformation) =>
        DoTransform<long, Int64Impl>(ref state, transformation);

    /// <inheritdoc cref="Transform(ref int, Func{int, int})"/>
    public static (ulong OldState, ulong NewState) Transform(ref ulong state, Func<ulong, ulong> transformation) =>
        DoTransform<ulong, UInt64Impl>(ref state, transformation);

    /// <summary>
    /// Reads the current state. Note that the state may have changed by the time this method returns.
    /// </summary>
    /// <typeparam name="T">The type of the state; this is generally an immutable type. Strongly consider using a record class.</typeparam>
    /// <param name="state">The location of the state.</param>
    /// <returns>The current state.</returns>
    public static T Read<T>(ref T state)
        where T : class? =>
        DoRead<T, RefImpl<T>>(ref state);

    /// <summary>
    /// Reads the current state atomically. Note that the state may have changed by the time this method returns.
    /// </summary>
    /// <param name="state">The location of the state.</param>
    /// <returns>The current state.</returns>
    public static int Read(ref int state) => DoRead<int, Int32Impl>(ref state);

    /// <inheritdoc cref="Read(ref int)"/>
    public static uint Read(ref uint state) => DoRead<uint, UInt32Impl>(ref state);

    /// <inheritdoc cref="Read(ref int)"/>
    public static long Read(ref long state) => DoRead<long, Int64Impl>(ref state);

    /// <inheritdoc cref="Read(ref int)"/>
    public static ulong Read(ref ulong state) => DoRead<ulong, UInt64Impl>(ref state);

    /// <summary>
    /// Compare-and-swap retry loop shared by every <c>Transform</c> overload.
    /// </summary>
    private static (T OldState, T NewState) DoTransform<T, TImpl>(ref T state, Func<T, T> transformation)
        where TImpl : IInterlockedCompareExchangeable<T>
    {
        ArgumentNullException.ThrowIfNull(transformation);

        var oldState = DoRead<T, TImpl>(ref state);
        while (true)
        {
            var newState = transformation(oldState);
            var observed = TImpl.InterlockedCompareExchange(ref state, newState, oldState);
            if (TImpl.AreEqual(observed, oldState))
                return (oldState, newState);

            // Another thread changed the state first. The failed CompareExchange already returned the
            // current value, so retry from it instead of issuing a second interlocked read.
            oldState = observed;
        }
    }

    /// <summary>
    /// Reads the state with a CompareExchange whose value and comparand are identical, so the location is never
    /// changed, while the read stays atomic (including 64-bit values on 32-bit platforms).
    /// </summary>
    private static T DoRead<T, TImpl>(ref T state)
        where TImpl : IInterlockedCompareExchangeable<T>
    {
        return TImpl.InterlockedCompareExchange(ref state, default!, default!);
    }

    /// <summary>
    /// Static dispatch over the per-type <see cref="Interlocked"/> CompareExchange overloads,
    /// so one generic retry loop serves every supported state type.
    /// </summary>
    private interface IInterlockedCompareExchangeable<T>
    {
        /// <summary>Forwards to the matching <see cref="Interlocked"/> CompareExchange overload.</summary>
        static abstract T InterlockedCompareExchange(ref T location, T value, T comparand);

        /// <summary>Compares two values the same way CompareExchange compares the comparand.</summary>
        static abstract bool AreEqual(T left, T right);
    }

    private readonly struct RefImpl<T> : IInterlockedCompareExchangeable<T>
        where T : class?
    {
        public static T InterlockedCompareExchange(ref T location, T value, T comparand) => Interlocked.CompareExchange(ref location, value, comparand);

        // CompareExchange<T> compares references, so success must be detected by identity, not value equality.
        public static bool AreEqual(T left, T right) => ReferenceEquals(left, right);
    }

    private readonly struct Int32Impl : IInterlockedCompareExchangeable<int>
    {
        public static int InterlockedCompareExchange(ref int location, int value, int comparand) => Interlocked.CompareExchange(ref location, value, comparand);

        public static bool AreEqual(int left, int right) => left == right;
    }

    private readonly struct Int64Impl : IInterlockedCompareExchangeable<long>
    {
        public static long InterlockedCompareExchange(ref long location, long value, long comparand) => Interlocked.CompareExchange(ref location, value, comparand);

        public static bool AreEqual(long left, long right) => left == right;
    }

    private readonly struct UInt32Impl : IInterlockedCompareExchangeable<uint>
    {
        public static uint InterlockedCompareExchange(ref uint location, uint value, uint comparand) => Interlocked.CompareExchange(ref location, value, comparand);

        public static bool AreEqual(uint left, uint right) => left == right;
    }

    private readonly struct UInt64Impl : IInterlockedCompareExchangeable<ulong>
    {
        public static ulong InterlockedCompareExchange(ref ulong location, ulong value, ulong comparand) => Interlocked.CompareExchange(ref location, value, comparand);

        public static bool AreEqual(ulong left, ulong right) => left == right;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// InterlockedState usage example: a simple reference counter.
// State consists of two parts: a counter and the referenced object.
// The reference count "latches" at 0; increments or decrements fail after it reaches 0.
/// <summary>
/// A reference count for an underlying target.
/// </summary>
public sealed class ReferenceCounter
{
    /// <summary>
    /// Immutable snapshot of the counter. The only transition into a count of 0 also clears the target,
    /// so a count of 0 always carries a <c>null</c> target.
    /// </summary>
    private sealed record class State(int Count, object? Target);

    private State _state;

    /// <summary>
    /// Creates a new reference counter with a reference count of 1 referencing the specified target.
    /// </summary>
    public ReferenceCounter(object? target) => _state = new(Count: 1, Target: target);

    /// <summary>
    /// Attempts to increment the reference count. Returns <c>false</c> if the reference count already reached zero.
    /// </summary>
    /// <exception cref="OverflowException">
    /// The reference count is already <see cref="int.MaxValue"/>; the state is left unchanged.
    /// </exception>
    public bool TryIncrementCount() => InterlockedState.Transform(ref _state, state => state switch
    {
        (0, _) => state,
        // checked: silently wrapping to int.MinValue would corrupt the count so it could never latch at zero.
        (var count, var target) => new(checked(count + 1), target),
    }).OldState.Count != 0;

    /// <summary>
    /// Attempts to decrement the reference count.
    /// Returns the target if this decrement is the one that caused the count to reach zero;
    /// otherwise (or if the target itself is <c>null</c>) returns <c>null</c>.
    /// Decrementing a count that has already reached zero has no effect.
    /// </summary>
    public object? TryDecrementCount()
    {
        var (oldState, _) = InterlockedState.Transform(ref _state, state => state switch
        {
            (0, _) => state,
            // Last reference released: drop the target so it is handed out exactly once, below.
            (1, _) => new(0, null),
            (var count, var target) => new(count - 1, target),
        });

        // Only the caller whose transition was 1 -> 0 observes a pre-decrement count of 1.
        return oldState.Count == 1 ? oldState.Target : null;
    }

    /// <summary>
    /// Attempts to retrieve the current target.
    /// Returns <c>null</c> if the reference count has reached zero.
    /// This is for advanced usage only; the reference count may reach zero by the time this function returns a non-<c>null</c> value.
    /// </summary>
    public object? TryGetTarget() => InterlockedState.Read(ref _state).Target;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment