Skip to content

Instantly share code, notes, and snippets.

@StephenCleary
Last active July 5, 2024 07:43
Show Gist options
  • Save StephenCleary/2d63729144bc3dbc1746e90d42a61861 to your computer and use it in GitHub Desktop.
Save StephenCleary/2d63729144bc3dbc1746e90d42a61861 to your computer and use it in GitHub Desktop.
Helper methods for working with interlocked state
using System;
using System.Threading;
/// <summary>
/// Lock-free helpers for reading and atomically replacing a single shared state location
/// by means of <see cref="Interlocked"/> compare-exchange retry loops.
/// </summary>
public static class InterlockedState
{
    /// <summary>
    /// Atomically replaces the state with the result of applying <paramref name="transformation"/> to it.
    /// </summary>
    /// <typeparam name="T">The type of the state; this is generally an immutable type. Strongly consider using a record class.</typeparam>
    /// <param name="state">The location of the state.</param>
    /// <param name="transformation">
    /// Computes the new state from the current one. It is re-run whenever another thread replaced the state
    /// in the meantime, so it may be invoked any number of times and should be a pure function.
    /// Conflicts are detected by reference identity, which is why an immutable state type is recommended.
    /// </param>
    /// <returns>The state that was replaced and the state that replaced it.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="transformation"/> is <c>null</c>.</exception>
    public static (T OldState, T NewState) Transform<T>(ref T state, Func<T, T> transformation)
        where T : class?
    {
        return TransformCore<T, RefImpl<T>>(ref state, transformation);
    }

    /// <summary>
    /// Atomically replaces the state with the result of applying <paramref name="transformation"/> to it.
    /// </summary>
    /// <param name="state">The location of the state.</param>
    /// <param name="transformation">
    /// Computes the new state from the current one. It is re-run whenever another thread changed the value
    /// in the meantime, so it may be invoked any number of times and should be a pure function.
    /// </param>
    /// <returns>The value that was replaced and the value that replaced it.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="transformation"/> is <c>null</c>.</exception>
    public static (int OldState, int NewState) Transform(ref int state, Func<int, int> transformation)
    {
        return TransformCore<int, Int32Impl>(ref state, transformation);
    }

    /// <inheritdoc cref="Transform(ref int, Func{int, int})"/>
    public static (uint OldState, uint NewState) Transform(ref uint state, Func<uint, uint> transformation)
    {
        return TransformCore<uint, UInt32Impl>(ref state, transformation);
    }

    /// <inheritdoc cref="Transform(ref int, Func{int, int})"/>
    public static (long OldState, long NewState) Transform(ref long state, Func<long, long> transformation)
    {
        return TransformCore<long, Int64Impl>(ref state, transformation);
    }

    /// <inheritdoc cref="Transform(ref int, Func{int, int})"/>
    public static (ulong OldState, ulong NewState) Transform(ref ulong state, Func<ulong, ulong> transformation)
    {
        return TransformCore<ulong, UInt64Impl>(ref state, transformation);
    }

    /// <summary>
    /// Atomically reads the current state. Another thread may replace it as soon as this method returns.
    /// </summary>
    /// <typeparam name="T">The type of the state; this is generally an immutable type. Strongly consider using a record class.</typeparam>
    /// <param name="state">The location of the state.</param>
    /// <returns>The state observed at the moment of the read.</returns>
    public static T Read<T>(ref T state)
        where T : class?
    {
        return ReadCore<T, RefImpl<T>>(ref state);
    }

    /// <summary>
    /// Atomically reads the current state. Another thread may replace it as soon as this method returns.
    /// </summary>
    /// <param name="state">The location of the state.</param>
    /// <returns>The value observed at the moment of the read.</returns>
    public static int Read(ref int state)
    {
        return ReadCore<int, Int32Impl>(ref state);
    }

    /// <inheritdoc cref="Read(ref int)"/>
    public static uint Read(ref uint state)
    {
        return ReadCore<uint, UInt32Impl>(ref state);
    }

    /// <inheritdoc cref="Read(ref int)"/>
    public static long Read(ref long state)
    {
        return ReadCore<long, Int64Impl>(ref state);
    }

    /// <inheritdoc cref="Read(ref int)"/>
    public static ulong Read(ref ulong state)
    {
        return ReadCore<ulong, UInt64Impl>(ref state);
    }

    // Shared optimistic-concurrency loop: snapshot, compute, then publish only if nobody raced us.
    private static (T OldState, T NewState) TransformCore<T, TImpl>(ref T state, Func<T, T> transformation)
        where TImpl : IInterlockedCompareExchangeable<T>
    {
        ArgumentNullException.ThrowIfNull(transformation);

        T observed;
        T replacement;
        do
        {
            observed = ReadCore<T, TImpl>(ref state);
            replacement = transformation(observed);
        }
        // CompareExchange hands back what was actually stored; a mismatch means another writer won, so retry.
        while (!TImpl.SameValue(TImpl.InterlockedCompareExchange(ref state, replacement, observed), observed));

        return (observed, replacement);
    }

    // Exchanging default for default never alters the location, so this is purely an atomic, fenced read.
    private static T ReadCore<T, TImpl>(ref T state)
        where TImpl : IInterlockedCompareExchangeable<T> =>
        TImpl.InterlockedCompareExchange(ref state, default!, default!);

    /// <summary>
    /// Per-type compare-exchange and equality, letting one generic retry loop serve every supported state type.
    /// </summary>
    private interface IInterlockedCompareExchangeable<T>
    {
        static abstract T InterlockedCompareExchange(ref T location, T value, T comparand);
        static abstract bool SameValue(T left, T right);
    }

    // Reference states: identity comparison, matching the semantics of Interlocked.CompareExchange for references.
    private readonly struct RefImpl<T> : IInterlockedCompareExchangeable<T>
        where T : class?
    {
        public static T InterlockedCompareExchange(ref T location, T value, T comparand) =>
            Interlocked.CompareExchange(ref location, value, comparand);

        public static bool SameValue(T left, T right) => ReferenceEquals(left, right);
    }

    private readonly struct Int32Impl : IInterlockedCompareExchangeable<int>
    {
        public static int InterlockedCompareExchange(ref int location, int value, int comparand) =>
            Interlocked.CompareExchange(ref location, value, comparand);

        public static bool SameValue(int left, int right) => left == right;
    }

    private readonly struct Int64Impl : IInterlockedCompareExchangeable<long>
    {
        public static long InterlockedCompareExchange(ref long location, long value, long comparand) =>
            Interlocked.CompareExchange(ref location, value, comparand);

        public static bool SameValue(long left, long right) => left == right;
    }

    private readonly struct UInt32Impl : IInterlockedCompareExchangeable<uint>
    {
        public static uint InterlockedCompareExchange(ref uint location, uint value, uint comparand) =>
            Interlocked.CompareExchange(ref location, value, comparand);

        public static bool SameValue(uint left, uint right) => left == right;
    }

    private readonly struct UInt64Impl : IInterlockedCompareExchangeable<ulong>
    {
        public static ulong InterlockedCompareExchange(ref ulong location, ulong value, ulong comparand) =>
            Interlocked.CompareExchange(ref location, value, comparand);

        public static bool SameValue(ulong left, ulong right) => left == right;
    }
}
// InterlockedState usage example: a simple reference counter.
// State consists of two parts: a counter and the referenced object.
// The reference count "latches" at 0; increments or decrements fail after it reaches 0.

/// <summary>
/// A thread-safe reference count for an underlying target.
/// </summary>
/// <remarks>
/// The count and target live together in one immutable <c>State</c> record that is swapped atomically
/// through <see cref="InterlockedState"/>. When the count drops from 1 to 0 the target is dropped
/// (replaced with <c>null</c>) and the counter can never be incremented again.
/// </remarks>
public sealed class ReferenceCounter
{
    /// <summary>Immutable snapshot of the counter; every transition publishes a new instance.</summary>
    private sealed record class State(int Count, object? Target);

    private State _state;

    /// <summary>
    /// Creates a new reference counter with a reference count of 1 referencing the specified target.
    /// </summary>
    /// <param name="target">
    /// The object being reference counted. If this is <c>null</c>, <see cref="TryDecrementCount"/>
    /// returns <c>null</c> even for the final decrement.
    /// </param>
    public ReferenceCounter(object? target) => _state = new(Count: 1, Target: target);

    /// <summary>
    /// Attempts to increment the reference count. Returns <c>false</c> if the reference count already reached zero.
    /// </summary>
    /// <returns><c>true</c> if the count was incremented; <c>false</c> if it had already reached zero.</returns>
    /// <exception cref="OverflowException">The reference count is already <see cref="int.MaxValue"/>; the state is left unchanged.</exception>
    public bool TryIncrementCount() => InterlockedState.Transform(ref _state, state => state switch
    {
        (0, _) => state,
        // checked: silently wrapping to int.MinValue would leave a negative count that never latches at zero,
        // so the target would never be released. Throwing aborts the transform without publishing anything.
        (var count, var target) => new(checked(count + 1), target),
    }).OldState.Count != 0;

    /// <summary>
    /// Attempts to decrement the reference count.
    /// Returns a non-null target if this decrement is the one that caused the count to reach zero.
    /// </summary>
    /// <returns>
    /// The target if this call moved the count from 1 to 0; otherwise <c>null</c>. Calls made after the count
    /// already reached zero do nothing and return <c>null</c>.
    /// </returns>
    public object? TryDecrementCount()
    {
        var (oldState, _) = InterlockedState.Transform(ref _state, state => state switch
        {
            (0, _) => state,
            // Last reference: release the target so the latched state no longer holds it.
            (1, _) => new(0, null),
            (var count, var target) => new(count - 1, target),
        });

        // Exactly one caller observes the 1 -> 0 transition; only that caller receives the target.
        return oldState.Count == 1 ? oldState.Target : null;
    }

    /// <summary>
    /// Attempts to retrieve the current target.
    /// Returns <c>null</c> if the reference count has reached zero.
    /// This is for advanced usage only; the reference count may reach zero by the time this function returns a non-<c>null</c> value.
    /// </summary>
    /// <returns>The current target, or <c>null</c> if the count reached zero or the counter was created with a <c>null</c> target.</returns>
    public object? TryGetTarget() => InterlockedState.Read(ref _state).Target;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment