Last active
June 6, 2023 08:10
-
-
Save davidfowl/77adf581c26c9ab85630e6bd76835a1c to your computer and use it in GitHub Desktop.
An implementation of MessagePipe: something like a channel, but with buffer management so you can peek at the message that was read and then advance past it.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Buffers; | |
using System.Net.WebSockets; | |
var builder = WebApplication.CreateBuilder(args);
var app = builder.Build();

// context.WebSockets can only accept upgrade requests when the WebSockets middleware is in the pipeline.
app.UseWebSockets();

app.MapGet("/ws", async (HttpContext context) =>
{
    const int MaxMessageSize = 1024 * 1024;

    if (!context.WebSockets.IsWebSocketRequest)
    {
        context.Response.StatusCode = StatusCodes.Status400BadRequest;
        return;
    }

    using var ws = await context.WebSockets.AcceptWebSocketAsync();
    var cancellationToken = context.RequestAborted;

    // Bytes received from the socket, framed per WebSocket message and tagged with the message type.
    var incomingMessages = new MessagePipe<WebSocketMessageType>(MaxMessageSize);

    // Frames to send, tagged with their message type and whether they end the WebSocket message.
    var outgoingMessages = new MessagePipe<(WebSocketMessageType messageType, bool endOfMessage)>(MaxMessageSize);

    var sending = SendLoopAsync();
    var processing = ProcessLoopAsync();

    try
    {
        await ReceiveLoopAsync();
    }
    finally
    {
        // Completing the incoming writer lets the processing loop drain what is queued and exit. The processing
        // loop then completes the outgoing writer, which lets the send loop drain and exit. Awaiting both keeps
        // the background work from using the socket after this request has ended, and surfaces their exceptions.
        await incomingMessages.CompleteWriterAsync();
        await processing;
        await sending;
    }

    if (ws.State == WebSocketState.CloseReceived)
    {
        // Finish the close handshake the client started.
        await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, statusDescription: null, cancellationToken);
    }

    // Reads frames from the socket into the incoming pipe until the client sends a Close frame.
    // Each complete WebSocket message is flushed as one pipe message tagged with its message type.
    async Task ReceiveLoopAsync()
    {
        while (true)
        {
            var result = await ws.ReceiveAsync(incomingMessages.GetMemory(512), cancellationToken);
            if (result.MessageType == WebSocketMessageType.Close)
            {
                return;
            }

            incomingMessages.AdvanceWriter(result.Count);

            if (result.EndOfMessage)
            {
                await incomingMessages.FlushMessageAsync(result.MessageType);
            }
        }
    }

    // Hands each incoming message to ProcessMessageAsync until the incoming pipe is drained,
    // then completes the outgoing writer so the send loop can finish.
    async Task ProcessLoopAsync()
    {
        try
        {
            while (true)
            {
                var message = await incomingMessages.ReadAsync();
                var buffer = message.Result.Buffer;

                // MessagePipe never records zero-length messages, so an empty read carries no message:
                // the incoming writer has completed.
                if (buffer.IsEmpty)
                {
                    break;
                }

                await ProcessMessageAsync(outgoingMessages, buffer, message.Metadata);
                incomingMessages.AdvanceReader();
            }
        }
        finally
        {
            await incomingMessages.CompleteReaderAsync();
            await outgoingMessages.CompleteWriterAsync();
        }
    }

    // Writes each outgoing message to the socket until the outgoing pipe is drained.
    async Task SendLoopAsync()
    {
        try
        {
            while (true)
            {
                var message = await outgoingMessages.ReadAsync();
                var buffer = message.Result.Buffer;

                // An empty read carries no message: the outgoing writer has completed.
                if (buffer.IsEmpty)
                {
                    break;
                }

                var (messageType, endOfMessage) = message.Metadata;

                if (buffer.IsSingleSegment)
                {
                    await ws.SendAsync(buffer.First, messageType, endOfMessage, cancellationToken);
                }
                else
                {
                    // Stay one segment behind while writing so the final send, which carries endOfMessage,
                    // is never a zero-byte frame.
                    var position = buffer.Start;
                    buffer.TryGet(ref position, out var previousSegment);
                    while (buffer.TryGet(ref position, out var segment))
                    {
                        await ws.SendAsync(previousSegment, messageType, endOfMessage: false, cancellationToken);
                        previousSegment = segment;
                    }

                    // End of message frame
                    await ws.SendAsync(previousSegment, messageType, endOfMessage, cancellationToken);
                }

                outgoingMessages.AdvanceReader();
            }
        }
        finally
        {
            // Completing the reader unblocks any writer waiting on a flush if sending fails.
            await outgoingMessages.CompleteReaderAsync();
        }
    }
});
// Application hook invoked once per complete incoming message. outgoingMessages is the pipe drained by the
// send loop, so replies would be written there. Currently a placeholder that does nothing.
Task ProcessMessageAsync(MessagePipe<(WebSocketMessageType, bool)> outgoingMessages, ReadOnlySequence<byte> buffer, WebSocketMessageType message) =>
    Task.CompletedTask;
// Start the web host and keep running until it shuts down.
await app.RunAsync();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Buffers; | |
using System.IO.Pipelines; | |
using System.Runtime.CompilerServices; | |
/// <summary>
/// Wrapper over a <see cref="System.IO.Pipelines.Pipe" /> that allows associating metadata with the underlying bytes.
/// Every flushed message records its size and metadata, so the reader can consume the pipe one message at a time.
/// </summary>
/// <typeparam name="TMetadata">The metadata associated with each message</typeparam>
class MessagePipe<TMetadata>
{
    private readonly Pipe _pipe;

    // One entry per flushed message, in flush order. Also serves as the lock object (see Sync).
    private readonly Queue<MessageMetadata> _messageMetadata = new();
    private readonly int _maxMessageSize;

    // The most recent result from the underlying reader. AdvanceReader computes positions against its buffer.
    private ReadResult _result;

    // True between a ReadAsync that returned a message and the AdvanceReader call that consumes it.
    private bool _hasPendingMessage;

    private object Sync => _messageMetadata;

    /// <summary>Creates a message pipe.</summary>
    /// <param name="maxMessageSize">The largest message, in bytes, that may be written.</param>
    /// <param name="pipeOptions">Options for the underlying pipe, or <see langword="null"/> for the defaults.</param>
    public MessagePipe(int maxMessageSize, PipeOptions? pipeOptions = null)
    {
        _maxMessageSize = maxMessageSize;
        _pipe = pipeOptions is null ? new Pipe() : new Pipe(pipeOptions);
    }

    /// <summary>
    /// Waits for the next message. The returned buffer covers exactly one message and must be released with
    /// <see cref="AdvanceReader"/> before reading again.
    /// </summary>
    /// <returns>
    /// The next message. When no complete message is available (the writer completed or the read was canceled),
    /// the buffer is empty and the result's IsCompleted / IsCanceled flags say why. A message only reports
    /// IsCompleted when it is the last data in the pipe.
    /// </returns>
    public async ValueTask<Message<TMetadata>> ReadAsync(CancellationToken cancellationToken = default)
    {
        var result = await _pipe.Reader.ReadAsync(cancellationToken);

        lock (Sync)
        {
            _result = result;
            var buffer = result.Buffer;

            if (!_messageMetadata.TryPeek(out var metadata) || buffer.Length < metadata.Size)
            {
                // No complete message is available. This happens when the writer completed (possibly leaving bytes
                // that were never flushed as a message) or the read was canceled. Release the read, consuming
                // nothing and examining everything, so the next ReadAsync waits for new data instead of throwing
                // because a read is still outstanding.
                _pipe.Reader.AdvanceTo(buffer.Start, buffer.End);
                _hasPendingMessage = false;
                return new(new ReadResult(ReadOnlySequence<byte>.Empty, result.IsCanceled, result.IsCompleted), default!);
            }

            _hasPendingMessage = true;

            // Only the final message may report completion. Otherwise a consumer that stops on IsCompleted would
            // drop the messages still queued behind this one.
            var isCompleted = result.IsCompleted && buffer.Length == metadata.Size;

            // Grab the message from the pipe and slice it
            return new(new ReadResult(buffer.Slice(0, metadata.Size), result.IsCanceled, isCompleted), metadata.Metadata);
        }
    }

    /// <summary>
    /// Consumes the message returned by the last <see cref="ReadAsync"/>.
    /// Does nothing if there is no unconsumed message.
    /// </summary>
    public void AdvanceReader()
    {
        lock (Sync)
        {
            // Calling AdvanceTo without an outstanding read would throw, so only advance past a message
            // that ReadAsync actually handed out.
            if (!_hasPendingMessage || !_messageMetadata.TryDequeue(out var metadata))
            {
                return;
            }

            _hasPendingMessage = false;

            // Advance a single message
            _pipe.Reader.AdvanceTo(_result.Buffer.GetPosition(metadata.Size));
        }
    }

    /// <summary>Cancels the pending <see cref="ReadAsync"/>, which then returns with IsCanceled set.</summary>
    public void CancelPendingRead() => _pipe.Reader.CancelPendingRead();

    /// <summary>Returns writable memory for the current, not-yet-flushed message.</summary>
    public Memory<byte> GetMemory(int sizeHint = 0) => _pipe.Writer.GetMemory(sizeHint);

    /// <summary>Returns a writable span for the current, not-yet-flushed message.</summary>
    public Span<byte> GetSpan(int sizeHint = 0) => _pipe.Writer.GetSpan(sizeHint);

    /// <summary>
    /// Commits <paramref name="bytes"/> written via <see cref="GetMemory"/> / <see cref="GetSpan"/> to the current
    /// message.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The unflushed message now exceeds the maximum message size. The bytes have already been advanced.
    /// </exception>
    public void AdvanceWriter(int bytes)
    {
        _pipe.Writer.Advance(bytes);

        var size = _pipe.Writer.UnflushedBytes;
        if (size > _maxMessageSize)
        {
            throw new InvalidOperationException($"Maximum message size of {_maxMessageSize} exceeded.");
        }
    }

    public ValueTask CompleteReaderAsync() => _pipe.Reader.CompleteAsync();

    public ValueTask CompleteWriterAsync() => _pipe.Writer.CompleteAsync();

    /// <summary>Writes <paramref name="buffer"/> as a single message tagged with <paramref name="metadata"/> and flushes it.</summary>
    /// <exception cref="InvalidOperationException">The buffer exceeds the maximum message size; nothing is written.</exception>
    public ValueTask WriteMessageAsync(ReadOnlyMemory<byte> buffer, TMetadata metadata, CancellationToken cancellationToken = default)
    {
        // Validate before writing so a rejected message leaves no stray bytes in the pipe.
        // Stray bytes would otherwise be flushed as part of the next message and break its framing.
        if (buffer.Length > _maxMessageSize)
        {
            throw new InvalidOperationException($"Maximum message size of {_maxMessageSize} exceeded.");
        }

        lock (Sync)
        {
            _pipe.Writer.Write(buffer.Span);
            return FlushMessageAsync(buffer.Length, metadata, cancellationToken);
        }
    }

    /// <summary>
    /// Records the last <paramref name="size"/> unflushed bytes as one message tagged with <paramref name="metadata"/>
    /// and flushes them to the reader. Zero-length messages are ignored (nothing is recorded or flushed).
    /// </summary>
    /// <exception cref="InvalidOperationException"><paramref name="size"/> exceeds the maximum message size.</exception>
    public ValueTask FlushMessageAsync(long size, TMetadata metadata, CancellationToken cancellationToken = default)
    {
        if (size == 0)
        {
            return ValueTask.CompletedTask;
        }

        if (size > _maxMessageSize)
        {
            throw new InvalidOperationException($"Maximum message size of {_maxMessageSize} exceeded.");
        }

        lock (Sync)
        {
            // Enqueue before flushing so any bytes the reader can observe always have metadata.
            _messageMetadata.Enqueue(new MessageMetadata(size, metadata));
            return GetAsValueTask(_pipe.Writer.FlushAsync(cancellationToken));
        }
    }

    /// <summary>Flushes all currently unflushed bytes as one message tagged with <paramref name="metadata"/>.</summary>
    public ValueTask FlushMessageAsync(TMetadata metadata, CancellationToken cancellationToken = default)
    {
        // Use unflushed bytes as the message frame
        return FlushMessageAsync(_pipe.Writer.UnflushedBytes, metadata, cancellationToken);
    }

    /// <summary>
    /// Converts a flush's <see cref="ValueTask{FlushResult}"/> into a non-generic <see cref="ValueTask"/>,
    /// avoiding a Task allocation when the flush already completed successfully. The FlushResult is discarded.
    /// </summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static ValueTask GetAsValueTask(in ValueTask<FlushResult> valueTask)
    {
        // Try to avoid the allocation from AsTask
        if (valueTask.IsCompletedSuccessfully)
        {
            // Signal consumption to the IValueTaskSource
            valueTask.GetAwaiter().GetResult();
            return default;
        }
        else
        {
            return new ValueTask(valueTask.AsTask());
        }
    }

    // Size (in bytes) and metadata of one flushed message.
    readonly struct MessageMetadata
    {
        public long Size { get; }

        public TMetadata Metadata { get; }

        public MessageMetadata(long size, TMetadata metadata)
        {
            Size = size;
            Metadata = metadata;
        }
    }
}
/// <summary>
/// One message produced by MessagePipe: a read result whose buffer holds the message bytes,
/// paired with the metadata that accompanied the message when it was flushed.
/// </summary>
/// <typeparam name="TMetadata">The metadata type carried alongside the bytes.</typeparam>
readonly struct Message<TMetadata>
{
    /// <summary>Pairs a read result with its metadata.</summary>
    public Message(ReadResult readResult, TMetadata metadata) => (Result, Metadata) = (readResult, metadata);

    /// <summary>The read result for this message.</summary>
    public ReadResult Result { get; }

    /// <summary>The metadata attached to this message.</summary>
    public TMetadata Metadata { get; }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment