Last active
May 3, 2022 09:07
-
-
Save theodorzoulias/715a19e0dc69bd23143826e23d826a83 to your computer and use it in GitHub Desktop.
PressureAwareUnboundedChannel -- https://stackoverflow.com/a/69284386/11178549
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Immutable; | |
using System.Diagnostics; | |
using System.Threading; | |
using System.Threading.Channels; | |
using System.Threading.Tasks; | |
using System.Threading.Tasks.Dataflow; | |
namespace CustomChannels | |
{ | |
/// <summary>
/// An unbounded <see cref="Channel{T}"/> that keeps a count of the items it holds and
/// tells subscribers when the channel comes under pressure and when the pressure is
/// released. The channel becomes "under pressure" once the count rises above the high
/// threshold, and is released once the count falls to or below the low threshold;
/// counts in between keep the current state (hysteresis).
/// </summary>
/// <remarks>
/// Each subscription delivers its notifications through its own
/// <see cref="ActionBlock{TInput}"/> (default degree of parallelism), so the
/// notifications of one subscriber do not overlap. A pressure episode that is already
/// over by the time its notification would be delivered is skipped entirely.
/// No notifications are started after the writer completes or after the subscription
/// is unsubscribed; in particular, completing the writer does not emit
/// "pressure released". If a notification action throws, that subscription stops
/// receiving notifications and the task returned by
/// <see cref="ISubscription.UnsubscribeAsync"/> is faulted.
/// </remarks>
/// <example>
/// <code>
/// var channel = new PressureAwareUnboundedChannel{Item}(500, 1000);
/// var subscription = channel.SubscribeForPressureNotifications(underPressure =>
/// {
///     if (underPressure) Producer.Pause(); else Producer.Resume();
/// });
/// // At this point the Producer is owned by the channel
/// //...
/// channel.Writer.Complete();
/// await channel.Reader.Completion;
/// await subscription.UnsubscribeAsync();
/// // At this point the Producer is no longer owned by the channel
/// </code>
/// </example>
public sealed class PressureAwareUnboundedChannel<T> : Channel<T>
{
    private readonly Channel<T> _channel;
    // Dedicated lock guarding all mutable state below and every Subscription's state.
    private readonly object _locker = new();
    private readonly int _highPressureThreshold;
    private readonly int _lowPressureThreshold;
    // Default scheduler for subscriptions that don't specify their own (may be null).
    private readonly TaskScheduler _eventTaskScheduler;
    private ImmutableArray<Subscription> _subscribers;
    private bool _writerCompleted = false;
    private bool _underPressure = false;
    // Pressure-tracking count; only read and written under _locker.
    private int _count = 0;

    /// <summary>A handle returned by <c>SubscribeForPressureNotifications</c>.</summary>
    public interface ISubscription
    {
        /// <summary>
        /// Stops the subscription. The returned task completes when any notification
        /// that is currently running has finished.
        /// </summary>
        public Task UnsubscribeAsync();
    }

    /// <summary>
    /// Delivers pressure notifications to one subscriber, one at a time. Every rising
    /// edge posts one "episode" to the block: the block emits <c>true</c>, waits for the
    /// episode's second step, and then emits <c>false</c> unless the step says to skip.
    /// </summary>
    private sealed class Subscription : ISubscription
    {
        // Include: the pressure was released, so emit "false".
        // Skip: the subscription was stopped, so end the episode silently.
        private enum SecondStep { Include, Skip }
        private readonly PressureAwareUnboundedChannel<T> _parent;
        private readonly ActionBlock<Task<SecondStep>> _block;
        // Second step of the episode that is currently open (null when released).
        private TaskCompletionSource<SecondStep> _pending;
        // Set under the parent's lock; read by the block without the lock.
        private volatile bool _stopped;

        public Subscription(PressureAwareUnboundedChannel<T> parent,
            Func<bool, Task> action, TaskScheduler scheduler = null)
        {
            _parent = parent;
            // The block is deliberately unbounded. With a bound of 1, a rising edge
            // posted while the previous episode was still finishing would be rejected
            // and lost. Ordering and one-at-a-time delivery come from the block's
            // default single degree of parallelism.
            _block = new ActionBlock<Task<SecondStep>>(async secondStep =>
            {
                if (_stopped) return;
                // The episode ended before we got to it: announcing "true" and then
                // immediately "false" would only be churn, and the subscriber's last
                // observed state ("false" or nothing) is already correct.
                if (secondStep.IsCompleted) return;
                await action(true); // Emit "under pressure"
                if (await secondStep == SecondStep.Skip) return;
                if (_stopped) return;
                await action(false); // Emit "pressure released"
            }, new()
            {
                TaskScheduler = scheduler ?? TaskScheduler.Default
            });
        }

        /// <summary>Forwards a pressure transition. Caller must hold the parent's lock.</summary>
        public void Post(bool underPressure)
        {
            Debug.Assert(Monitor.IsEntered(_parent._locker));
            if (underPressure)
            {
                var tcs = new TaskCompletionSource<SecondStep>(
                    TaskCreationOptions.RunContinuationsAsynchronously);
                // Fails only when the block has been completed or has faulted.
                if (_block.Post(tcs.Task))
                {
                    // Transitions alternate, so no episode should still be open here.
                    // Close one defensively so "true"/"false" stay balanced.
                    _pending?.TrySetResult(SecondStep.Include);
                    _pending = tcs;
                }
            }
            else
            {
                _pending?.TrySetResult(SecondStep.Include);
                _pending = null;
            }
        }

        /// <summary>
        /// Stops all future notifications and ends any open episode without emitting
        /// "false". Caller must hold the parent's lock.
        /// </summary>
        public void Complete()
        {
            Debug.Assert(Monitor.IsEntered(_parent._locker));
            // Set before completing the pending step, so the block sees the flag.
            _stopped = true;
            _block.Complete();
            _pending?.TrySetResult(SecondStep.Skip);
            _pending = null;
        }

        public Task UnsubscribeAsync()
        {
            lock (_parent._locker)
            {
                int index = _parent._subscribers.IndexOf(this);
                if (index >= 0)
                {
                    _parent._subscribers = _parent._subscribers.RemoveAt(index);
                    Complete();
                }
                return _block.Completion;
            }
        }
    }

    /// <summary>
    /// Subscribes an asynchronous callback that receives <c>true</c> when the channel
    /// comes under pressure and <c>false</c> when the pressure is released. If the
    /// channel is already under pressure, <c>true</c> is delivered right away.
    /// </summary>
    /// <param name="action">The callback. Invocations for one subscription do not overlap.</param>
    /// <param name="scheduler">
    /// Scheduler for the callbacks. Defaults to the scheduler passed to the constructor,
    /// or <see cref="TaskScheduler.Default"/> if none was passed.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="action"/> is null.</exception>
    public ISubscription SubscribeForPressureNotifications(Func<bool, Task> action,
        TaskScheduler scheduler = null)
    {
        if (action is null) throw new ArgumentNullException(nameof(action));
        lock (_locker)
        {
            var subscriber = new Subscription(this, action, scheduler ?? _eventTaskScheduler);
            if (_writerCompleted)
            {
                subscriber.Complete();
            }
            else
            {
                _subscribers = _subscribers.Add(subscriber);
                // Late subscribers must learn about a pressure episode already in progress.
                if (_underPressure) subscriber.Post(true);
            }
            return subscriber;
        }
    }

    /// <summary>Synchronous-callback overload of <c>SubscribeForPressureNotifications</c>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="action"/> is null.</exception>
    public ISubscription SubscribeForPressureNotifications(Action<bool> action,
        TaskScheduler scheduler = null)
    {
        if (action is null) throw new ArgumentNullException(nameof(action));
        return SubscribeForPressureNotifications(
            underPressure => { action(underPressure); return Task.CompletedTask; }, scheduler);
    }

    /// <param name="lowPressureThreshold">Count at or below which pressure is released (must be non-negative).</param>
    /// <param name="highPressureThreshold">Count above which the channel is under pressure (must be at least the low threshold).</param>
    /// <param name="eventTaskScheduler">Default scheduler for subscription callbacks (null means <see cref="TaskScheduler.Default"/>).</param>
    /// <exception cref="ArgumentOutOfRangeException">The thresholds are invalid.</exception>
    public PressureAwareUnboundedChannel(int lowPressureThreshold,
        int highPressureThreshold, TaskScheduler eventTaskScheduler = null)
    {
        if (highPressureThreshold < lowPressureThreshold)
            throw new ArgumentOutOfRangeException(nameof(highPressureThreshold));
        if (lowPressureThreshold < 0)
            throw new ArgumentOutOfRangeException(nameof(lowPressureThreshold));
        _highPressureThreshold = highPressureThreshold;
        _lowPressureThreshold = lowPressureThreshold;
        _eventTaskScheduler = eventTaskScheduler;
        _channel = Channel.CreateUnbounded<T>();
        _subscribers = ImmutableArray.Create<Subscription>();
        this.Writer = new ChannelWriter(this);
        this.Reader = new ChannelReader(this);
    }

    /// <summary>Writer that forwards to the inner channel and reports successful writes.</summary>
    private sealed class ChannelWriter : ChannelWriter<T>
    {
        private readonly PressureAwareUnboundedChannel<T> _parent;
        public ChannelWriter(PressureAwareUnboundedChannel<T> parent)
            => _parent = parent;
        public override bool TryComplete(Exception error = null)
        {
            bool success = _parent._channel.Writer.TryComplete(error);
            if (success) _parent.Complete();
            return success;
        }
        public override bool TryWrite(T item)
        {
            bool success = _parent._channel.Writer.TryWrite(item);
            if (success) _parent.SignalWriteOrRead(1);
            return success;
        }
        public override ValueTask<bool> WaitToWriteAsync(
            CancellationToken cancellationToken = default)
            => _parent._channel.Writer.WaitToWriteAsync(cancellationToken);
    }

    /// <summary>Reader that forwards to the inner channel and reports successful reads.</summary>
    private sealed class ChannelReader : ChannelReader<T>
    {
        private readonly PressureAwareUnboundedChannel<T> _parent;
        public ChannelReader(PressureAwareUnboundedChannel<T> parent)
            => _parent = parent;
        public override Task Completion => _parent._channel.Reader.Completion;
        public override bool CanCount => true;
        // The inner channel's count is exact. The pressure counter can briefly go
        // negative when a read is signalled before the matching write.
        public override int Count => _parent._channel.Reader.Count;
        public override bool TryRead(out T item)
        {
            bool success = _parent._channel.Reader.TryRead(out item);
            if (success) _parent.SignalWriteOrRead(-1);
            return success;
        }
        public override ValueTask<bool> WaitToReadAsync(
            CancellationToken cancellationToken = default)
            => _parent._channel.Reader.WaitToReadAsync(cancellationToken);
    }

    /// <summary>Marks the writer as completed and stops every subscription (once).</summary>
    private void Complete()
    {
        lock (_locker)
        {
            if (_writerCompleted) return;
            _writerCompleted = true;
            foreach (var subscriber in _subscribers) subscriber.Complete();
            _subscribers = _subscribers.Clear();
        }
    }

    /// <summary>
    /// Applies a count change (+1 for a write, -1 for a read) and notifies subscribers
    /// when the pressure state flips.
    /// </summary>
    private void SignalWriteOrRead(int countDelta)
    {
        lock (_locker)
        {
            _count += countDelta;
            bool underPressure;
            if (_count > _highPressureThreshold)
                underPressure = true;
            else if (_count <= _lowPressureThreshold)
                underPressure = false;
            else
                return; // Inside the hysteresis band: keep the current state.
            if (underPressure == _underPressure) return;
            _underPressure = underPressure;
            foreach (var subscriber in _subscribers) subscriber.Post(underPressure);
        }
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment