Skip to content

Instantly share code, notes, and snippets.

@koturn
Last active May 20, 2025 21:26
Show Gist options
  • Save koturn/6d4ed98774d94d2a786ffe5628b95d49 to your computer and use it in GitHub Desktop.
Save koturn/6d4ed98774d94d2a786ffe5628b95d49 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
#if !NET8_0_OR_GREATER
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
#endif // !NET8_0_OR_GREATER
using System.Threading;
using System.Threading.Tasks;
#if !NET9_0_OR_GREATER
using Lock = System.Object;
#endif // !NET9_0_OR_GREATER
namespace Koturn.Tasks
{
    /// <summary>
    /// Provides a task scheduler that ensures a maximum concurrency level while running on top of the thread pool.
    /// </summary>
    /// <remarks>
    /// <para>
    /// Queued tasks are kept in FIFO order and executed by at most <see cref="MaxDegreeOfParallelism"/>
    /// thread pool work items at a time.
    /// </para>
    /// <seealso href="https://learn.microsoft.com/en-us/dotnet/api/system.threading.tasks.taskscheduler?view=net-9.0"/>
    /// </remarks>
    public sealed class LimitedConcurrencyLevelTaskScheduler : TaskScheduler
    {
        /// <summary>
        /// <para>The scheduler whose work items the current thread is processing, or <c>null</c> if none.</para>
        /// <para>
        /// This is stored per instance rather than as a boolean flag. A worker of one scheduler must never
        /// inline a task belonging to a different scheduler, because that would exceed the other scheduler's limit.
        /// </para>
        /// </summary>
        [ThreadStatic]
        private static LimitedConcurrencyLevelTaskScheduler? _currentThreadScheduler;

        /// <summary>
        /// The maximum concurrency level allowed by this scheduler.
        /// </summary>
        public sealed override int MaximumConcurrencyLevel => _maxDegreeOfParallelism;

        /// <summary>
        /// <para>Gets or sets the maximum number of tasks this scheduler executes concurrently.</para>
        /// <para>
        /// This is an alias of <see cref="MaximumConcurrencyLevel"/>. Unlike that read-only property, it can be
        /// changed at run time.
        /// </para>
        /// <para>
        /// Raising the value immediately dispatches additional workers for tasks that are already queued.
        /// Lowering it lets surplus workers retire after finishing the task they are currently running.
        /// </para>
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">The assigned value is less than 1.</exception>
        public int MaxDegreeOfParallelism
        {
            get => _maxDegreeOfParallelism;
            set
            {
#if NET8_0_OR_GREATER
                ArgumentOutOfRangeException.ThrowIfLessThan(value, 1);
#else
                ThrowIfLessThan(value, 1);
#endif // NET8_0_OR_GREATER
                // Workers compare against _maxDegreeOfParallelism while holding the lock.
                // The read-compare-write must therefore happen under the same lock to avoid racing
                // with them or with a concurrent setter.
                lock (_taskListLock)
                {
                    var previous = _maxDegreeOfParallelism;
                    _maxDegreeOfParallelism = value;
                    if (value <= previous)
                    {
                        // Excess workers notice the lower limit on their next loop iteration and exit.
                        return;
                    }

                    // Start just enough new workers to cover the queued tasks, up to the new limit.
                    var additionalWorkers = Math.Min(value, _tasks.Count) - _delegatesQueuedOrRunning;
                    for (var i = 0; i < additionalWorkers; i++)
                    {
                        _delegatesQueuedOrRunning++;
                        NotifyThreadPoolOfPendingWork();
                    }
                }
            }
        }

        /// <summary>
        /// <para>The tasks waiting to be executed, in FIFO order.</para>
        /// <para>Guarded by <see cref="_taskListLock"/>.</para>
        /// </summary>
        private readonly LinkedList<Task> _tasks = new();
        /// <summary>
        /// Lock object guarding <see cref="_tasks"/> and <see cref="_delegatesQueuedOrRunning"/>, and writes to
        /// <see cref="_maxDegreeOfParallelism"/>.
        /// </summary>
        private readonly Lock _taskListLock = new();
        /// <summary>
        /// <para>The maximum concurrency level allowed by this scheduler.</para>
        /// <para>Written only while holding <see cref="_taskListLock"/>.</para>
        /// </summary>
        private int _maxDegreeOfParallelism;
        /// <summary>
        /// <para>The number of worker delegates currently queued to, or running on, the thread pool for this scheduler.</para>
        /// <para>Guarded by <see cref="_taskListLock"/>.</para>
        /// </summary>
        private int _delegatesQueuedOrRunning;

        /// <summary>
        /// Creates a new instance with the specified degree of parallelism.
        /// </summary>
        /// <param name="maxDegreeOfParallelism">The maximum number of tasks to run concurrently; must be at least 1.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxDegreeOfParallelism"/> is less than 1.</exception>
        public LimitedConcurrencyLevelTaskScheduler(int maxDegreeOfParallelism)
        {
            // A limit below 1 would make QueueTask never dispatch a worker, so every task would hang forever.
#if NET8_0_OR_GREATER
            ArgumentOutOfRangeException.ThrowIfLessThan(maxDegreeOfParallelism, 1);
#else
            ThrowIfLessThan(maxDegreeOfParallelism, 1);
#endif // NET8_0_OR_GREATER
            _maxDegreeOfParallelism = maxDegreeOfParallelism;
        }

        /// <summary>
        /// Queues a task to the scheduler.
        /// </summary>
        /// <param name="task">A task.</param>
        protected sealed override void QueueTask(Task task)
        {
            // Add the task to the list of tasks to be processed. If there aren't enough
            // delegates currently queued or running to process tasks, schedule another.
            lock (_taskListLock)
            {
                _tasks.AddLast(task);
                if (_delegatesQueuedOrRunning < _maxDegreeOfParallelism)
                {
                    _delegatesQueuedOrRunning++;
                    NotifyThreadPoolOfPendingWork();
                }
            }
        }

        /// <summary>
        /// Queues one worker delegate to the thread pool.
        /// The worker drains <see cref="_tasks"/> until the queue is empty or the concurrency limit has been lowered
        /// below the current worker count.
        /// </summary>
        /// <remarks>
        /// Callers must already have incremented <see cref="_delegatesQueuedOrRunning"/> for this worker.
        /// The worker decrements it when it exits.
        /// </remarks>
        private void NotifyThreadPoolOfPendingWork()
        {
            ThreadPool.UnsafeQueueUserWorkItem(_ =>
            {
                // Mark this thread as a worker of this scheduler. This enables inlining of this scheduler's tasks only.
                _currentThreadScheduler = this;
                try
                {
                    while (true)
                    {
                        Task item;
                        lock (_taskListLock)
                        {
                            // Exit if the limit was lowered or nothing is left to run.
                            // The counter is decremented under the same lock that observed the condition.
                            var first = _tasks.First;
                            if (_delegatesQueuedOrRunning > _maxDegreeOfParallelism || first is null)
                            {
                                _delegatesQueuedOrRunning--;
                                break;
                            }
                            item = first.Value;
                            _tasks.RemoveFirst();
                        }
                        // Execute outside the lock so other workers and QueueTask are not blocked.
                        TryExecuteTask(item);
                    }
                }
                finally
                {
                    // We're done processing items on the current thread.
                    _currentThreadScheduler = null;
                }
            }, null);
        }

        /// <summary>
        /// Attempts to execute the specified task on the current thread.
        /// </summary>
        /// <param name="task">A task to execute.</param>
        /// <param name="taskWasPreviouslyQueued">A flag whether <paramref name="task"/> is queued previously.</param>
        /// <returns><c>true</c> if task was successfully executed, <c>false</c> if it was not.</returns>
        protected sealed override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
        {
            // Inline only on a thread that already counts against THIS scheduler's limit.
            // If the task was previously queued, it must first be removed from the queue.
            return ReferenceEquals(_currentThreadScheduler, this)
                && (!taskWasPreviouslyQueued || TryDequeue(task))
                && TryExecuteTask(task);
        }

        /// <summary>
        /// Attempt to remove a previously scheduled task from the scheduler.
        /// </summary>
        /// <param name="task">A task to remove from the scheduler</param>
        /// <returns><c>true</c> if the element containing value is successfully removed; otherwise, <c>false</c>.</returns>
        protected sealed override bool TryDequeue(Task task)
        {
            lock (_taskListLock)
            {
                return _tasks.Remove(task);
            }
        }

        /// <summary>
        /// Gets a snapshot of the tasks currently waiting in this scheduler's queue.
        /// </summary>
        /// <returns>An array copy of the queued tasks, taken while holding the lock.</returns>
        /// <exception cref="NotSupportedException">The lock could not be acquired immediately.</exception>
        protected sealed override IEnumerable<Task> GetScheduledTasks()
        {
            var lockTaken = false;
            try
            {
#if NET9_0_OR_GREATER
                lockTaken = _taskListLock.TryEnter();
#else
                Monitor.TryEnter(_taskListLock, ref lockTaken);
#endif // NET9_0_OR_GREATER
                if (!lockTaken)
                {
                    throw new NotSupportedException();
                }
                // Copy while the lock is held.
                // Returning the live list would let callers enumerate it while workers mutate it.
                var snapshot = new Task[_tasks.Count];
                _tasks.CopyTo(snapshot, 0);
                return snapshot;
            }
            finally
            {
                if (lockTaken)
                {
#if NET9_0_OR_GREATER
                    _taskListLock.Exit();
#else
                    Monitor.Exit(_taskListLock);
#endif // NET9_0_OR_GREATER
                }
            }
        }

#if !NET8_0_OR_GREATER
        /// <summary>
        /// Throw <see cref="ArgumentOutOfRangeException"/>.
        /// </summary>
        /// <typeparam name="T">The type of the objects.</typeparam>
        /// <param name="value">The value of the argument that causes this exception.</param>
        /// <param name="other">The value to compare with <paramref name="value"/>.</param>
        /// <param name="paramName">The name of the parameter with which <paramref name="value"/> corresponds.</param>
        /// <exception cref="ArgumentOutOfRangeException">Always thrown.</exception>
        [DoesNotReturn]
        private static void ThrowLess<T>(T value, T other, string? paramName)
        {
            throw new ArgumentOutOfRangeException(paramName, value, $"'{value}' must be greater than or equal to '{other}'.");
        }

        /// <summary>
        /// Throws an <see cref="ArgumentOutOfRangeException"/> if <paramref name="value"/> is less than <paramref name="other"/>.
        /// </summary>
        /// <typeparam name="T">The type of the objects to validate.</typeparam>
        /// <param name="value">The argument to validate as greater than or equal to <paramref name="other"/>.</param>
        /// <param name="other">The value to compare with <paramref name="value"/>.</param>
        /// <param name="paramName">The name of the parameter with which <paramref name="value"/> corresponds.</param>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="value"/> is less than <paramref name="other"/>.</exception>
        internal static void ThrowIfLessThan<T>(T value, T other, [CallerArgumentExpression(nameof(value))] string? paramName = null)
            where T : IComparable<T>
        {
            if (value.CompareTo(other) < 0)
            {
                ThrowLess(value, other, paramName);
            }
        }
#endif // !NET8_0_OR_GREATER
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment