Created
January 20, 2025 14:09
-
-
Save rodion-m/e648ca00f29df4ae97b4c143a7cc9a3f to your computer and use it in GitHub Desktop.
DynamicParallelism.ForEachAsync from DeepSeek-R1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
/// <summary>
/// Asynchronous parallel for-each whose degree of parallelism may change while the loop is running.
/// </summary>
public static class DynamicParallelism
{
    /// <summary>
    /// Longest time the dispatcher waits before re-reading the degree of parallelism while every
    /// slot is busy. Any item completing also triggers an immediate re-read.
    /// </summary>
    private static readonly TimeSpan ScaleCheckInterval = TimeSpan.FromMilliseconds(100);

    /// <summary>
    /// Invokes <paramref name="bodyAsync"/> for every element of <paramref name="source"/>, keeping
    /// at most <paramref name="getDegreeOfParallelism"/>() invocations in flight at any time.
    /// </summary>
    /// <typeparam name="TSource">Element type of <paramref name="source"/>.</typeparam>
    /// <param name="source">
    /// Items to process. Enumerated exactly once, sequentially, by a single dispatcher loop
    /// (never concurrently), so a plain non-thread-safe enumerable is fine.
    /// </param>
    /// <param name="bodyAsync">
    /// Per-item work. Receives a token that is cancelled when <paramref name="cancellationToken"/>
    /// is cancelled or when any other item (or the source / DOP callback) fails.
    /// </param>
    /// <param name="getDegreeOfParallelism">
    /// Re-read by the dispatcher before each dispatch round: after in-flight items complete and at
    /// least every 100 ms while items remain to be started. Values below 1 are treated as 1.
    /// Lowering the value never interrupts running items; new items simply wait until the
    /// in-flight count drops below the new limit.
    /// </param>
    /// <param name="preserveOrder">
    /// Currently has no effect: items are always dispatched in source order and their completion
    /// order is unspecified. Kept for API compatibility.
    /// </param>
    /// <param name="cancellationToken">Stops dispatching new items and is linked into running ones.</param>
    /// <returns>A task that completes only after every started item has finished.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="source"/>, <paramref name="bodyAsync"/> or
    /// <paramref name="getDegreeOfParallelism"/> is <see langword="null"/>.
    /// </exception>
    /// <exception cref="AggregateException">
    /// One or more items, the source enumerator, or <paramref name="getDegreeOfParallelism"/> threw.
    /// No new items are started after the first failure. Takes precedence over cancellation.
    /// </exception>
    /// <exception cref="OperationCanceledException">
    /// <paramref name="cancellationToken"/> was cancelled and nothing failed.
    /// </exception>
    public static async Task ForEachAsync<TSource>(
        IEnumerable<TSource> source,
        Func<TSource, CancellationToken, Task> bodyAsync,
        Func<int> getDegreeOfParallelism,
        bool preserveOrder = false,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(source);
        ArgumentNullException.ThrowIfNull(bodyAsync);
        ArgumentNullException.ThrowIfNull(getDegreeOfParallelism);
        _ = preserveOrder; // Intentionally unused; see the parameter documentation.

        // Linked so the first failure can stop sibling items without touching the caller's token.
        using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        var token = linkedCts.Token;
        var exceptions = new ConcurrentQueue<Exception>();
        var running = new List<Task>();

        try
        {
            using var enumerator = source.GetEnumerator();
            var hasMore = true;
            Task? tick = null;

            while (hasMore && !token.IsCancellationRequested)
            {
                // Re-read every round so the limit can grow or shrink while the loop runs.
                var limit = Math.Max(1, getDegreeOfParallelism());
                while (running.Count < limit && !token.IsCancellationRequested)
                {
                    if (!enumerator.MoveNext())
                    {
                        hasMore = false;
                        break;
                    }

                    // Copy into a fresh local so the closure does not observe later MoveNext calls.
                    var item = enumerator.Current;
                    // Task.Run keeps a body's synchronous prefix off the dispatcher loop.
                    running.Add(Task.Run(() => RunItemAsync(item)));
                }

                if (!hasMore)
                {
                    break; // Everything has been dispatched; the finally block drains the rest.
                }

                // Wake on the first completion, or after the interval so a raised limit is noticed
                // even while every in-flight item is long-running. The tick is reused until it
                // fires so a burst of fast completions does not allocate a timer per round.
                tick ??= Task.Delay(ScaleCheckInterval, token);
                var waitSet = new Task[running.Count + 1];
                running.CopyTo(waitSet);
                waitSet[^1] = tick;
                await Task.WhenAny(waitSet).ConfigureAwait(false);

                if (tick.IsCompleted)
                {
                    tick = null;
                }

                running.RemoveAll(static t => t.IsCompleted);
            }
        }
        catch (OperationCanceledException) when (token.IsCancellationRequested)
        {
            // The source or the DOP callback observed cancellation; reported after draining below.
        }
        catch (Exception ex)
        {
            // The source enumerator or getDegreeOfParallelism failed: stop dispatching and report it.
            exceptions.Enqueue(ex);
            linkedCts.Cancel();
        }
        finally
        {
            // Never return while bodies are still running. Item wrappers record body failures
            // instead of faulting, so this wait is only for completion.
            await Task.WhenAll(running).ConfigureAwait(false);
        }

        if (!exceptions.IsEmpty)
        {
            throw new AggregateException(exceptions).Flatten();
        }

        cancellationToken.ThrowIfCancellationRequested();

        // Runs one item, converting its outcome into state instead of a faulted task.
        async Task RunItemAsync(TSource item)
        {
            if (token.IsCancellationRequested)
            {
                return; // Queued before a failure/cancellation but not yet started: skip it.
            }

            try
            {
                await bodyAsync(item, token).ConfigureAwait(false);
            }
            catch (OperationCanceledException) when (token.IsCancellationRequested)
            {
                // Cooperative cancellation (caller cancelled or a sibling failed), not a failure.
            }
            catch (Exception ex)
            {
                exceptions.Enqueue(ex);
                linkedCts.Cancel(); // Fail fast: stop dispatching and signal running siblings.
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment