Skip to content

Instantly share code, notes, and snippets.

@davidfowl
Last active June 6, 2023 08:10
Show Gist options
  • Save davidfowl/77adf581c26c9ab85630e6bd76835a1c to your computer and use it in GitHub Desktop.
Save davidfowl/77adf581c26c9ab85630e6bd76835a1c to your computer and use it in GitHub Desktop.
An implementation of MessagePipe: something like a channel, but with buffer management, so you can peek at the message that was read and then advance past it.
using System.Buffers;
using System.Net.WebSockets;
var builder = WebApplication.CreateBuilder(args);
var app = builder.Build();

// WebSocket upgrade requests are only recognized once the WebSocket middleware is in the pipeline.
app.UseWebSockets();

app.MapGet("/ws", async (HttpContext context) =>
{
    const int MaxMessageSize = 1024 * 1024;

    if (!context.WebSockets.IsWebSocketRequest)
    {
        // 400 Bad Request: this endpoint only speaks WebSockets.
        context.Response.StatusCode = 400;
        return;
    }

    var cancellationToken = context.RequestAborted;
    var ws = await context.WebSockets.AcceptWebSocketAsync();
    var incomingMessages = new MessagePipe<WebSocketMessageType>(MaxMessageSize);
    var outgoingMessages = new MessagePipe<(WebSocketMessageType messageType, bool endOfMessage)>(MaxMessageSize);

    // Both loops end once their pipe's writer is completed; they are awaited below instead of being fire-and-forget.
    var sendTask = SendLoopAsync();
    var processTask = ProcessLoopAsync();

    try
    {
        while (true)
        {
            var result = await ws.ReceiveAsync(incomingMessages.GetMemory(512), cancellationToken);
            if (result.MessageType == WebSocketMessageType.Close)
            {
                break;
            }

            incomingMessages.AdvanceWriter(result.Count);

            // Only a complete WebSocket message is framed and made visible to the processing loop.
            if (result.EndOfMessage)
            {
                await incomingMessages.FlushMessageAsync(result.MessageType);
            }
        }
    }
    finally
    {
        // Lets the processing loop drain what has already been received and then stop.
        await incomingMessages.CompleteWriterAsync();
    }

    // Every received message is processed and every queued reply is sent before the connection is closed.
    await Task.WhenAll(processTask, sendTask);

    if (ws.State == WebSocketState.CloseReceived)
    {
        // Finish the close handshake the client started, echoing its close status.
        await ws.CloseOutputAsync(ws.CloseStatus ?? WebSocketCloseStatus.NormalClosure, ws.CloseStatusDescription, cancellationToken);
    }

    // Sends queued outgoing messages until the outgoing pipe is completed and drained.
    async Task SendLoopAsync()
    {
        try
        {
            while (true)
            {
                var message = await outgoingMessages.ReadAsync(cancellationToken);
                var buffer = message.Result.Buffer;

                // Zero-length messages are never queued (FlushMessageAsync returns early for size 0),
                // so an empty buffer means there is no message: the writer completed or the read was canceled.
                if (buffer.IsEmpty)
                {
                    break;
                }

                var (messageType, endOfMessage) = message.Metadata;
                if (buffer.IsSingleSegment)
                {
                    await ws.SendAsync(buffer.First, messageType, endOfMessage, cancellationToken);
                }
                else
                {
                    // Stay one segment behind while writing so the send that carries endOfMessage
                    // is the last segment rather than a zero-byte write.
                    var position = buffer.Start;
                    buffer.TryGet(ref position, out var previous);
                    while (buffer.TryGet(ref position, out var segment))
                    {
                        await ws.SendAsync(previous, messageType, endOfMessage: false, cancellationToken);
                        previous = segment;
                    }

                    // End of message frame
                    await ws.SendAsync(previous, messageType, endOfMessage, cancellationToken);
                }

                outgoingMessages.AdvanceReader();
            }
        }
        finally
        {
            await outgoingMessages.CompleteReaderAsync();
        }
    }

    // Hands each received message to ProcessMessageAsync until the incoming pipe is completed and drained,
    // then completes the outgoing pipe so the send loop can finish.
    async Task ProcessLoopAsync()
    {
        try
        {
            while (true)
            {
                var message = await incomingMessages.ReadAsync(cancellationToken);
                var buffer = message.Result.Buffer;

                // An empty buffer means no message is left (see SendLoopAsync). Stopping on IsCompleted instead
                // could skip messages that are still buffered behind the current one.
                if (buffer.IsEmpty)
                {
                    break;
                }

                // The buffer is only valid until AdvanceReader is called.
                await ProcessMessageAsync(outgoingMessages, buffer, message.Metadata);
                incomingMessages.AdvanceReader();
            }
        }
        finally
        {
            await incomingMessages.CompleteReaderAsync();
            // No more replies can be produced; let the send loop drain and exit.
            await outgoingMessages.CompleteWriterAsync();
        }
    }
});

// Application hook invoked once per complete incoming message. Replies written to outgoingMessages are
// delivered by the send loop. The buffer is advanced past as soon as this returns, so it must not be kept.
Task ProcessMessageAsync(MessagePipe<(WebSocketMessageType, bool)> outgoingMessages, ReadOnlySequence<byte> buffer, WebSocketMessageType messageType)
{
    // Process the message here
    return Task.CompletedTask;
}

app.Run();
using System.Buffers;
using System.IO.Pipelines;
using System.Runtime.CompilerServices;
/// <summary>
/// Wrapper over a <see cref="System.IO.Pipelines.Pipe" /> that allows associating metadata with the underlying bytes.
/// </summary>
/// <remarks>
/// Each message is framed by a flush: its size and metadata are queued before its bytes are flushed, so a reader
/// that observes flushed bytes always finds the matching metadata. <see cref="ReadAsync"/> returns exactly one
/// message and <see cref="AdvanceReader"/> consumes exactly that message. The outstanding read is tracked in a
/// single set of fields, so only one reader may use the pipe at a time. Zero-length messages are dropped.
/// </remarks>
/// <typeparam name="TMetadata">The metadata associated with each message</typeparam>
class MessagePipe<TMetadata>
{
    private readonly Pipe _pipe;
    private readonly Queue<MessageMetadata> _messageMetadata = new();
    private readonly int _maxMessageSize;

    // State of the outstanding read; released by AdvanceReader.
    private ReadResult _result;
    private bool _readInProgress;
    private bool _readReturnedMessage;

    private object Sync => _messageMetadata;

    /// <param name="maxMessageSize">Largest message, in bytes, that may be written; larger messages throw.</param>
    /// <param name="pipeOptions">Options for the underlying pipe; <see cref="PipeOptions.Default"/> when null.</param>
    public MessagePipe(int maxMessageSize, PipeOptions? pipeOptions = null)
    {
        _maxMessageSize = maxMessageSize;
        _pipe = new Pipe(pipeOptions ?? PipeOptions.Default);
    }

    /// <summary>
    /// Waits for the next message and returns it without consuming it; call <see cref="AdvanceReader"/> when done with it.
    /// </summary>
    /// <returns>
    /// The next message, whose buffer holds exactly that message's bytes. <see cref="ReadResult.IsCompleted"/> is set on
    /// a message only when no further buffered bytes follow it. When no message is available (the read was canceled or
    /// the writer completed), the buffer is empty and <see cref="ReadResult.IsCanceled"/>/<see cref="ReadResult.IsCompleted"/>
    /// report why; the metadata is then <c>default</c>.
    /// </returns>
    public async ValueTask<Message<TMetadata>> ReadAsync(CancellationToken cancellationToken = default)
    {
        var result = await _pipe.Reader.ReadAsync(cancellationToken);
        lock (Sync)
        {
            _result = result;
            _readInProgress = true;
            _readReturnedMessage = _messageMetadata.TryPeek(out var metadata);

            if (!_readReturnedMessage)
            {
                // Canceled read or completed writer (possibly leaving bytes that were never framed as a message).
                // Report the flags with an empty buffer rather than default(Message), which would hide them.
                var empty = result.Buffer.Slice(0, 0);
                return new(new ReadResult(empty, result.IsCanceled, result.IsCompleted), default!);
            }

            // Grab the message from the pipe and slice it
            var messageBuffer = result.Buffer.Slice(0, metadata.Size);

            // Completion describes the whole pipe; it only applies to this message if nothing is buffered after it.
            // Otherwise a consumer that stops on IsCompleted would drop the messages that follow.
            var isCompleted = result.IsCompleted && messageBuffer.Length == result.Buffer.Length;
            return new(new ReadResult(messageBuffer, result.IsCanceled, isCompleted), metadata.Metadata);
        }
    }

    /// <summary>
    /// Releases what the last <see cref="ReadAsync"/> returned: consumes that message, or, if it returned no message,
    /// just ends the read so the next <see cref="ReadAsync"/> can proceed. Buffers from that read must not be used
    /// afterwards. Does nothing when no read is outstanding.
    /// </summary>
    public void AdvanceReader()
    {
        lock (Sync)
        {
            if (!_readInProgress)
            {
                // Noop
                return;
            }

            _readInProgress = false;

            // Only dequeue when the read actually returned the head message; metadata enqueued after a
            // message-less read belongs to bytes that read's buffer does not contain.
            if (_readReturnedMessage && _messageMetadata.TryDequeue(out var metadata))
            {
                // Advance a single message; later messages stay buffered for the next read.
                _pipe.Reader.AdvanceTo(_result.Buffer.GetPosition(metadata.Size));
            }
            else
            {
                // Keep any unframed bytes but mark everything examined so the next read waits for new data.
                _pipe.Reader.AdvanceTo(_result.Buffer.Start, _result.Buffer.End);
            }

            _result = default;
        }
    }

    public void CancelPendingRead() => _pipe.Reader.CancelPendingRead();

    public Memory<byte> GetMemory(int sizeHint = 0) => _pipe.Writer.GetMemory(sizeHint);

    public Span<byte> GetSpan(int sizeHint = 0) => _pipe.Writer.GetSpan(sizeHint);

    /// <summary>Commits <paramref name="bytes"/> written into GetMemory/GetSpan to the current, unflushed message.</summary>
    /// <exception cref="InvalidOperationException">The unflushed message exceeds the maximum message size.</exception>
    public void AdvanceWriter(int bytes)
    {
        _pipe.Writer.Advance(bytes);
        ThrowIfTooLarge(_pipe.Writer.UnflushedBytes);
    }

    public ValueTask CompleteReaderAsync() => _pipe.Reader.CompleteAsync();

    public ValueTask CompleteWriterAsync() => _pipe.Writer.CompleteAsync();

    /// <summary>Copies <paramref name="buffer"/> into the pipe and flushes it as one message.</summary>
    /// <exception cref="InvalidOperationException">The buffer exceeds the maximum message size.</exception>
    public ValueTask WriteMessageAsync(ReadOnlyMemory<byte> buffer, TMetadata metadata, CancellationToken cancellationToken = default)
    {
        // Validate before writing: once copied into the writer, rejected bytes would stay unflushed and be
        // counted as part of the next message.
        ThrowIfTooLarge(buffer.Length);

        lock (Sync)
        {
            _pipe.Writer.Write(buffer.Span);
            return FlushMessageAsync(buffer.Length, metadata, cancellationToken);
        }
    }

    /// <summary>Frames the last <paramref name="size"/> unflushed bytes as one message and flushes them.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="size"/> exceeds the maximum message size.</exception>
    public ValueTask FlushMessageAsync(long size, TMetadata metadata, CancellationToken cancellationToken = default)
    {
        if (size == 0)
        {
            // Empty messages are dropped: nothing is queued or flushed.
            return ValueTask.CompletedTask;
        }

        ThrowIfTooLarge(size);

        lock (Sync)
        {
            // Queue the frame before flushing so a reader that observes these bytes always finds their metadata.
            _messageMetadata.Enqueue(new MessageMetadata(size, metadata));
            return GetAsValueTask(_pipe.Writer.FlushAsync(cancellationToken));
        }
    }

    /// <summary>Frames all currently unflushed bytes as one message and flushes them.</summary>
    public ValueTask FlushMessageAsync(TMetadata metadata, CancellationToken cancellationToken = default)
    {
        // Use unflushed bytes as the message frame
        return FlushMessageAsync(_pipe.Writer.UnflushedBytes, metadata, cancellationToken);
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static ValueTask GetAsValueTask(in ValueTask<FlushResult> valueTask)
    {
        // Try to avoid the allocation from AsTask
        if (valueTask.IsCompletedSuccessfully)
        {
            // Signal consumption to the IValueTaskSource
            valueTask.GetAwaiter().GetResult();
            return default;
        }

        return new ValueTask(valueTask.AsTask());
    }

    // Throws when a message of the given size would exceed the configured maximum.
    private void ThrowIfTooLarge(long size)
    {
        if (size > _maxMessageSize)
        {
            throw new InvalidOperationException($"Maximum message size of {_maxMessageSize} exceeded.");
        }
    }

    // Size and metadata of one queued message, kept in the order the messages were flushed.
    private readonly struct MessageMetadata
    {
        public long Size { get; }

        public TMetadata Metadata { get; }

        public MessageMetadata(long size, TMetadata metadata)
        {
            Size = size;
            Metadata = metadata;
        }
    }
}
/// <summary>
/// One message handed out by a MessagePipe: the read result carrying the message's bytes,
/// paired with the metadata that was flushed alongside them.
/// </summary>
/// <typeparam name="TMetadata">Type of the metadata carried with each message.</typeparam>
readonly struct Message<TMetadata>
{
    /// <summary>Pairs a read result with its metadata.</summary>
    public Message(ReadResult readResult, TMetadata metadata) => (Result, Metadata) = (readResult, metadata);

    /// <summary>The read result; when produced by MessagePipe.ReadAsync its buffer is sliced to this message.</summary>
    public ReadResult Result { get; }

    /// <summary>The metadata that accompanied this message.</summary>
    public TMetadata Metadata { get; }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment