Last active
May 20, 2025 21:26
-
-
Save koturn/6d4ed98774d94d2a786ffe5628b95d49 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Threading; | |
using System.Threading.Tasks; | |
#if !NET9_0_OR_GREATER | |
using Lock = object; | |
#endif // !NET9_0_OR_GREATER | |
namespace Koturn.Tasks
{
    /// <summary>
    /// Provides a task scheduler that ensures a maximum concurrency level while running on top of the thread pool.
    /// </summary>
    /// <remarks>
    /// <para>
    /// Queued tasks are kept in a FIFO list and drained by at most <see cref="MaxDegreeOfParallelism"/>
    /// worker delegates posted to the thread pool.
    /// </para>
    /// <seealso href="https://learn.microsoft.com/en-us/dotnet/api/system.threading.tasks.taskscheduler?view=net-9.0"/>
    /// </remarks>
    public sealed class LimitedConcurrencyLevelTaskScheduler : TaskScheduler
    {
        /// <summary>
        /// The scheduler whose worker loop is running on the current thread,
        /// or <c>null</c> when the current thread is not a worker of any instance of this type.
        /// </summary>
        /// <remarks>
        /// Storing the instance (rather than a plain flag) ensures a scheduler only inlines tasks on its
        /// own worker threads; inlining on another scheduler's worker would let this scheduler exceed its limit.
        /// </remarks>
        [ThreadStatic]
        private static LimitedConcurrencyLevelTaskScheduler? _currentThreadScheduler;

        /// <summary>
        /// Gets the maximum concurrency level allowed by this scheduler.
        /// </summary>
        public sealed override int MaximumConcurrencyLevel => _maxDegreeOfParallelism;

        /// <summary>
        /// <para>Gets or sets the maximum number of tasks this scheduler runs concurrently.</para>
        /// <para>This is the same value as <see cref="MaximumConcurrencyLevel"/>, but unlike that property it can be changed at run time.</para>
        /// </summary>
        /// <remarks>
        /// Raising the limit immediately posts extra worker delegates for pending tasks.
        /// Lowering it takes effect as surplus workers finish their current task and retire.
        /// </remarks>
        /// <exception cref="ArgumentOutOfRangeException">Thrown when the value being set is less than 1.</exception>
        public int MaxDegreeOfParallelism
        {
            get => _maxDegreeOfParallelism;
            set
            {
#if NET8_0_OR_GREATER
                ArgumentOutOfRangeException.ThrowIfLessThan(value, 1);
#else
                ThrowIfLessThan(value, 1);
#endif // NET8_0_OR_GREATER
                // Read-modify-write under the lock so concurrent setters and workers see a consistent limit.
                lock (_taskListLock)
                {
                    var oldValue = _maxDegreeOfParallelism;
                    _maxDegreeOfParallelism = value;
                    if (value <= oldValue)
                    {
                        // Surplus workers notice the lower limit and retire in ProcessPendingItems().
                        return;
                    }
                    // Existing workers are busy with (or about to take) their own tasks, so add one new
                    // worker per pending task, up to the headroom the new limit allows.
                    var additionalWorkers = Math.Min(value - _delegatesQueuedOrRunning, _tasks.Count);
                    for (var i = 0; i < additionalWorkers; i++)
                    {
                        _delegatesQueuedOrRunning++;
                        NotifyThreadPoolOfPendingWork();
                    }
                }
            }
        }

        /// <summary>
        /// Tasks waiting to be executed, in FIFO order.
        /// Guarded by <see cref="_taskListLock"/>.
        /// </summary>
        private readonly LinkedList<Task> _tasks = new();

        /// <summary>
        /// Lock guarding <see cref="_tasks"/>, <see cref="_delegatesQueuedOrRunning"/>
        /// and writes to <see cref="_maxDegreeOfParallelism"/>.
        /// </summary>
        private readonly Lock _taskListLock = new();

        /// <summary>
        /// The maximum concurrency level allowed by this scheduler.
        /// Written only while holding <see cref="_taskListLock"/>.
        /// </summary>
        private int _maxDegreeOfParallelism;

        /// <summary>
        /// Number of worker delegates currently queued to, or running on, the thread pool.
        /// Guarded by <see cref="_taskListLock"/>.
        /// </summary>
        private int _delegatesQueuedOrRunning;

        /// <summary>
        /// Creates a new instance with the specified degree of parallelism.
        /// </summary>
        /// <param name="maxDegreeOfParallelism">The maximum number of tasks to run concurrently; must be at least 1.</param>
        /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxDegreeOfParallelism"/> is less than 1.</exception>
        public LimitedConcurrencyLevelTaskScheduler(int maxDegreeOfParallelism)
        {
            // With a limit below 1, QueueTask would never start a worker and queued tasks would never run.
#if NET8_0_OR_GREATER
            ArgumentOutOfRangeException.ThrowIfLessThan(maxDegreeOfParallelism, 1);
#else
            ThrowIfLessThan(maxDegreeOfParallelism, 1);
#endif // NET8_0_OR_GREATER
            _maxDegreeOfParallelism = maxDegreeOfParallelism;
        }

        /// <summary>
        /// Queues a task to the scheduler.
        /// </summary>
        /// <param name="task">The task to queue.</param>
        protected sealed override void QueueTask(Task task)
        {
            lock (_taskListLock)
            {
                _tasks.AddLast(task);
                // Start another worker only while we are below the concurrency limit;
                // otherwise an existing worker will pick the task up.
                if (_delegatesQueuedOrRunning < _maxDegreeOfParallelism)
                {
                    _delegatesQueuedOrRunning++;
                    NotifyThreadPoolOfPendingWork();
                }
            }
        }

        /// <summary>
        /// Posts one worker delegate to the thread pool to drain this scheduler's task list.
        /// </summary>
        private void NotifyThreadPoolOfPendingWork()
        {
            // The lambda captures nothing, so the compiler caches a single delegate instance;
            // the scheduler is passed through the state argument instead.
            ThreadPool.UnsafeQueueUserWorkItem(
                static state => ((LimitedConcurrencyLevelTaskScheduler)state!).ProcessPendingItems(),
                this);
        }

        /// <summary>
        /// Worker loop run on a thread-pool thread. It executes queued tasks until the list is empty,
        /// or until this worker is surplus because the limit was lowered.
        /// </summary>
        private void ProcessPendingItems()
        {
            // Mark this thread as one of our workers so TryExecuteTaskInline may inline here.
            var previousScheduler = _currentThreadScheduler;
            _currentThreadScheduler = this;
            try
            {
                while (true)
                {
                    Task item;
                    lock (_taskListLock)
                    {
                        var node = _tasks.First;
                        // Retire when there are more workers than the (possibly lowered) limit,
                        // or when nothing is left to run.
                        if (_delegatesQueuedOrRunning > _maxDegreeOfParallelism || node is null)
                        {
                            _delegatesQueuedOrRunning--;
                            return;
                        }
                        item = node.Value;
                        _tasks.RemoveFirst();
                    }
                    // Execute outside the lock so other workers and QueueTask are not blocked.
                    TryExecuteTask(item);
                }
            }
            finally
            {
                _currentThreadScheduler = previousScheduler;
            }
        }

        /// <summary>
        /// Attempts to execute the specified task on the current thread.
        /// </summary>
        /// <param name="task">A task to execute.</param>
        /// <param name="taskWasPreviouslyQueued">Whether <paramref name="task"/> was previously queued to this scheduler.</param>
        /// <returns><c>true</c> if the task was executed inline; otherwise, <c>false</c>.</returns>
        protected sealed override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
        {
            // Only inline on this scheduler's own worker threads; any other thread would add
            // an extra concurrent task beyond the limit.
            if (!ReferenceEquals(_currentThreadScheduler, this))
            {
                return false;
            }
            // A previously queued task is inlined only if it can still be removed from the pending list.
            return (!taskWasPreviouslyQueued || TryDequeue(task))
                && TryExecuteTask(task);
        }

        /// <summary>
        /// Attempts to remove a previously scheduled task from the scheduler.
        /// </summary>
        /// <param name="task">A task to remove from the scheduler.</param>
        /// <returns><c>true</c> if the task was found and removed; otherwise, <c>false</c>.</returns>
        protected sealed override bool TryDequeue(Task task)
        {
            lock (_taskListLock)
            {
                return _tasks.Remove(task);
            }
        }

        /// <summary>
        /// Gets a snapshot of the tasks currently waiting in this scheduler.
        /// </summary>
        /// <returns>An array copy of the pending tasks.</returns>
        /// <exception cref="NotSupportedException">Thrown when the task list lock cannot be acquired immediately.</exception>
        protected sealed override IEnumerable<Task> GetScheduledTasks()
        {
            // Never block here (this is intended for debuggers, which may freeze other threads);
            // fail with NotSupportedException instead, as TaskScheduler's contract allows.
#if NET9_0_OR_GREATER
            if (!_taskListLock.TryEnter())
            {
                throw new NotSupportedException();
            }
            try
            {
                return SnapshotTasks();
            }
            finally
            {
                _taskListLock.Exit();
            }
#else
            var lockTaken = false;
            try
            {
                Monitor.TryEnter(_taskListLock, ref lockTaken);
                if (!lockTaken)
                {
                    throw new NotSupportedException();
                }
                return SnapshotTasks();
            }
            finally
            {
                if (lockTaken)
                {
                    Monitor.Exit(_taskListLock);
                }
            }
#endif // NET9_0_OR_GREATER
        }

        /// <summary>
        /// Copies the pending tasks into a new array so callers never enumerate the live list.
        /// The caller must hold <see cref="_taskListLock"/>.
        /// </summary>
        /// <returns>A new array containing the pending tasks in queue order.</returns>
        private Task[] SnapshotTasks()
        {
            var snapshot = new Task[_tasks.Count];
            _tasks.CopyTo(snapshot, 0);
            return snapshot;
        }

#if !NET8_0_OR_GREATER
        /// <summary>
        /// Throws an <see cref="ArgumentOutOfRangeException"/>.
        /// </summary>
        /// <typeparam name="T">The type of the objects.</typeparam>
        /// <param name="value">The value of the argument that causes this exception.</param>
        /// <param name="other">The value to compare with <paramref name="value"/>.</param>
        /// <param name="paramName">The name of the parameter with which <paramref name="value"/> corresponds.</param>
        /// <exception cref="ArgumentOutOfRangeException">Always thrown.</exception>
        [System.Diagnostics.CodeAnalysis.DoesNotReturn]
        private static void ThrowLess<T>(T value, T other, string? paramName)
        {
            throw new ArgumentOutOfRangeException(paramName, value, $"'{value}' must be greater than or equal to '{other}'.");
        }

        /// <summary>
        /// Throws an <see cref="ArgumentOutOfRangeException"/> if <paramref name="value"/> is less than <paramref name="other"/>.
        /// </summary>
        /// <typeparam name="T">The type of the objects to validate.</typeparam>
        /// <param name="value">The argument to validate as greater than or equal to <paramref name="other"/>.</param>
        /// <param name="other">The value to compare with <paramref name="value"/>.</param>
        /// <param name="paramName">The name of the parameter with which <paramref name="value"/> corresponds.</param>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="value"/> is less than <paramref name="other"/>.</exception>
        internal static void ThrowIfLessThan<T>(T value, T other, [System.Runtime.CompilerServices.CallerArgumentExpression(nameof(value))] string? paramName = null)
            where T : IComparable<T>
        {
            if (value.CompareTo(other) < 0)
            {
                ThrowLess(value, other, paramName);
            }
        }
#endif // !NET8_0_OR_GREATER
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment