Last active
May 15, 2024 17:37
-
-
Save dlyz/1c2f892e482f599093bdb9021e20c26f to your computer and use it in GitHub Desktop.
Server-Sent Events parser implementation, close to the proposal in https://github.com/dotnet/runtime/issues/98105
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Buffers; | |
using System.Runtime.CompilerServices; | |
using System.Runtime.InteropServices; | |
using System.Text; | |
/// <summary>
/// A single event produced by <see cref="SseParser"/>.
/// </summary>
/// <typeparam name="T">Type of the parsed event payload.</typeparam>
/// <param name="EventType">Value of the <c>event</c> field, or <see cref="SseParser.DefaultEventType"/> when the field is absent or empty.</param>
/// <param name="Data">Payload returned by the <see cref="EventDataParser{T}"/> for this event.</param>
/// <param name="LastEventId">Most recent <c>id</c> field value seen on the stream at the time the event was produced.</param>
/// <param name="ReconnectionInterval">Most recent <c>retry</c> field value seen on the stream at the time the event was produced.</param>
public readonly record struct SseItem<T>(string EventType, T Data, string LastEventId, TimeSpan ReconnectionInterval);
/// <summary>
/// Converts the raw bytes of an event's <c>data</c> field(s) into a value of type <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Type of the parsed payload.</typeparam>
/// <param name="eventType">Type of the event the data belongs to.</param>
/// <param name="data">Raw payload bytes; the span must not be retained after the call returns.</param>
/// <returns>The parsed payload.</returns>
public delegate T EventDataParser<out T>(
    string eventType,
    ReadOnlySpan<byte> data
);
/// <summary>
/// Incremental parser for the <c>text/event-stream</c> (Server-Sent Events) format.
/// See https://github.com/dotnet/runtime/issues/98105
/// https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
/// </summary>
public static class SseParser
{
    /// <summary>Event type reported for events whose <c>event</c> field is missing or empty.</summary>
    public const string DefaultEventType = "message";

    private static readonly SearchValues<byte> _newlines = SearchValues.Create("\r\n"u8);

    /// <summary>
    /// Reads <paramref name="sseStream"/> until it ends and yields one item per event that carries data.
    /// </summary>
    /// <typeparam name="T">Type produced by <paramref name="dataParser"/>.</typeparam>
    /// <param name="sseStream">Stream containing UTF-8 encoded event-stream content.</param>
    /// <param name="dataParser">Callback that converts the event data bytes into <typeparamref name="T"/>.</param>
    /// <param name="cancellationToken">Token observed while reading and between events.</param>
    /// <returns>
    /// A lazily evaluated sequence of events. Bytes after the last blank line (an unterminated event)
    /// are dropped when the stream ends.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="sseStream"/> or <paramref name="dataParser"/> is null
    /// (thrown when enumeration starts, because this is an iterator).
    /// </exception>
    public static async IAsyncEnumerable<SseItem<T>> Parse<T>(
        Stream sseStream,
        EventDataParser<T> dataParser,
        [EnumeratorCancellation] CancellationToken cancellationToken = default
    )
    {
        ArgumentNullException.ThrowIfNull(sseStream);
        ArgumentNullException.ThrowIfNull(dataParser);

        using var buffer = new Buffer();
        var firstEvent = true;
        // number of leading buffer bytes that were already scanned without finding an event end
        var inspectedLength = 0;
        // both values persist across events until the stream changes them explicitly
        var lastEventId = string.Empty;
        var reconnectionInterval = Timeout.InfiniteTimeSpan;
        int bytesRead;
        do
        {
            bytesRead = await sseStream.ReadAsync(buffer.GetReadTargetMemory(), cancellationToken).ConfigureAwait(false);
            buffer.Advance(bytesRead);
            var bufferStart = 0;

            // we should perform at least one roundtrip
            // even if no bytes were read and the whole buffer is inspected,
            // because streamCompleted is changed in this case
            do
            {
                var eventEnd = IndexOfEndOfEvent(
                    inspectedLength - bufferStart,
                    buffer.Span[bufferStart..],
                    streamCompleted: bytesRead <= 0,
                    out var firstNLSize,
                    out var secondNLSize
                );
                if (eventEnd == -1)
                {
                    inspectedLength = buffer.Length;
                    break;
                }

                // first newline is a part of an event
                eventEnd += bufferStart + firstNLSize;

                // skipping BOM
                var eventStart = bufferStart + (firstEvent && buffer.Span.StartsWith("\uFEFF"u8) ? "\uFEFF"u8.Length : 0);
                firstEvent = false;

                var @event = ParseEvent(
                    buffer.Span[eventStart..eventEnd],
                    dataParser,
                    lastEventId: ref lastEventId,
                    reconnectionInterval: ref reconnectionInterval
                );

                // including the empty line
                inspectedLength = bufferStart = eventEnd + secondNLSize;

                cancellationToken.ThrowIfCancellationRequested();
                if (@event is not null)
                {
                    yield return @event.Value;
                }
            }
            while (inspectedLength < buffer.Length);

            // drop the consumed events and keep the unfinished tail at the start of the buffer
            buffer.ShiftLeft(bufferStart);
            inspectedLength -= bufferStart;
        }
        while (bytesRead > 0);
    }

    /// <summary>
    /// Interprets the lines of a single event block.
    /// </summary>
    /// <param name="eventBytes">All lines of the event; every line is terminated by a newline.</param>
    /// <param name="dataParser">Callback that converts the joined data bytes into <typeparamref name="T"/>.</param>
    /// <param name="lastEventId">Updated when the block contains an <c>id</c> field without a NUL byte.</param>
    /// <param name="reconnectionInterval">Updated when the block contains a <c>retry</c> field made of ASCII digits.</param>
    /// <returns>The parsed event, or <c>null</c> when the block has no <c>data</c> field.</returns>
    private static SseItem<T>? ParseEvent<T>(
        Span<byte> eventBytes,
        EventDataParser<T> dataParser,
        ref string lastEventId,
        ref TimeSpan reconnectionInterval
    )
    {
        var hasAtLeastOneDataChunk = false;
        Span<byte> firstDataChunk = default;
        var totalDataLength = 0;
        // spans cannot be stored in a list, so additional chunks are remembered as offsets into eventBytes
        List<(int From, int Length)>? secondaryDataChunks = null;
        var currentSpan = eventBytes;
        string eventType = string.Empty;

        while (currentSpan.Length != 0)
        {
            var fieldLen = IndexOfNewLine(0, currentSpan, out var nlSize);
            System.Diagnostics.Debug.Assert(fieldLen != -1);
            var field = currentSpan[..fieldLen];
            currentSpan = currentSpan[(fieldLen + nlSize)..];

            var colonIndex = field.IndexOf((byte)':');
            Span<byte> fieldName;
            Span<byte> fieldValue;
            if (colonIndex == -1)
            {
                // the whole line is a field name with an empty value
                fieldName = field;
                fieldValue = default;
            }
            else if (colonIndex == 0)
            {
                // this is a comment
                continue;
            }
            else
            {
                fieldName = field[0..colonIndex];
                fieldValue = field[(colonIndex + 1)..];
                // a single space right after the colon is not a part of the value
                if (fieldValue.Length != 0 && fieldValue[0] == ' ')
                {
                    fieldValue = fieldValue[1..];
                }
            }

            if (fieldName.SequenceEqual("event"u8))
            {
                eventType = Encoding.UTF8.GetString(fieldValue);
            }
            else if (fieldName.SequenceEqual("data"u8))
            {
                if (!hasAtLeastOneDataChunk)
                {
                    firstDataChunk = fieldValue;
                    totalDataLength += fieldValue.Length;
                    hasAtLeastOneDataChunk = true;
                }
                else
                {
                    secondaryDataChunks ??= [];
                    secondaryDataChunks.Add(GetNestedSpanBounds(eventBytes, fieldValue));
                    // extra one of the separator-newline
                    totalDataLength += 1 + fieldValue.Length;
                }
            }
            else if (fieldName.SequenceEqual("id"u8))
            {
                if (!fieldValue.Contains((byte)'\0'))
                {
                    // BROWSER DISPATCH (1): last event id persist until changed explicitly
                    lastEventId = Encoding.UTF8.GetString(fieldValue);
                }
            }
            else if (fieldName.SequenceEqual("retry"u8))
            {
                // the spec accepts ASCII digits only; signs and whitespace
                // (which int.TryParse would tolerate) make the field ignored
                if (fieldValue.Length != 0
                    && fieldValue.IndexOfAnyExceptInRange((byte)'0', (byte)'9') == -1
                    && int.TryParse(fieldValue, out var retry))
                {
                    reconnectionInterval = TimeSpan.FromMilliseconds(retry);
                }
            }
        }

        // BROWSER DISPATCH (2): empty data -> don't dispatch
        if (!hasAtLeastOneDataChunk)
        {
            return null;
        }

        // BROWSER DISPATCH (6): empty type -> default type
        if (eventType.Length == 0)
        {
            eventType = DefaultEventType;
        }

        T data;
        // BROWSER DISPATCH (3): data last \n is removed (not added)
        if (secondaryDataChunks is null)
        {
            data = dataParser(eventType, firstDataChunk);
        }
        else
        {
            // several data lines: join them with '\n' in a temporary pooled buffer
            var dataBuf = ArrayPool<byte>.Shared.Rent(totalDataLength);
            try
            {
                firstDataChunk.CopyTo(dataBuf.AsSpan());
                var dataBufLength = firstDataChunk.Length;
                foreach (var (from, length) in secondaryDataChunks)
                {
                    dataBuf[dataBufLength++] = (byte)'\n';
                    var dataChunk = eventBytes.Slice(from, length);
                    dataChunk.CopyTo(dataBuf.AsSpan(dataBufLength));
                    dataBufLength += length;
                }

                // pass the actual event type, same as in the single-chunk branch
                data = dataParser(eventType, dataBuf.AsSpan(0, dataBufLength));
            }
            finally
            {
                ArrayPool<byte>.Shared.Return(dataBuf);
            }
        }

        return new(eventType, data, LastEventId: lastEventId, ReconnectionInterval: reconnectionInterval);

        // Returns the offset and the length of `nested` inside of `eventBytes`.
        static (int From, int Length) GetNestedSpanBounds(Span<byte> eventBytes, Span<byte> nested)
        {
            if (nested.IsEmpty)
            {
                // this is corner case because we can not take pointers from empty buffers
                return (0, 0);
            }
            else
            {
                var eventByteOffset = (int)Unsafe.ByteOffset(
                    ref MemoryMarshal.GetReference(eventBytes),
                    ref MemoryMarshal.GetReference(nested)
                );
                return (eventByteOffset, nested.Length);
            }
        }
    }

    /// <summary>
    /// Finds the first newline (<c>\r\n</c>, <c>\r</c> or <c>\n</c>) at or after <paramref name="startIndex"/>.
    /// </summary>
    /// <param name="startIndex">Index to start the search from.</param>
    /// <param name="bytes">Bytes to search in.</param>
    /// <param name="size">Length of the found newline (1 or 2), or 0 when nothing is found.</param>
    /// <returns>Index of the newline relative to the start of <paramref name="bytes"/>, or -1.</returns>
    private static int IndexOfNewLine(int startIndex, Span<byte> bytes, out int size)
    {
        bytes = bytes[startIndex..];
        var index = bytes.IndexOfAny(_newlines);
        if (index == -1)
        {
            size = default;
            return -1;
        }

        if (bytes[index] == '\r' && index + 1 < bytes.Length && bytes[index + 1] == '\n')
        {
            size = 2;
        }
        else
        {
            size = 1;
        }

        return startIndex + index;
    }

    /// <summary>
    /// Checks whether a newline starts exactly at <paramref name="index"/>.
    /// </summary>
    /// <param name="index">Index to check; must be inside <paramref name="bytes"/>.</param>
    /// <param name="bytes">Bytes to check in.</param>
    /// <param name="size">Length of the newline (1 or 2), or 0 when there is no newline at the index.</param>
    private static bool IsNewline(int index, Span<byte> bytes, out int size)
    {
        if (bytes[index] == '\r')
        {
            if (index + 1 < bytes.Length && bytes[index + 1] == '\n')
            {
                size = 2;
            }
            else
            {
                size = 1;
            }

            return true;
        }
        else if (bytes[index] == '\n')
        {
            size = 1;
            return true;
        }
        else
        {
            size = 0;
            return false;
        }
    }

    /// <summary>
    /// Looks for the end of the next event: a newline immediately followed by another newline,
    /// or a single newline at the very start of <paramref name="bytes"/> (an empty event).
    /// </summary>
    /// <param name="inspectedLength">Number of leading bytes that were already searched by a previous call.</param>
    /// <param name="bytes">Unconsumed part of the buffer.</param>
    /// <param name="streamCompleted">
    /// When false, a lone <c>\r</c> at the very end of the buffer is not accepted as the closing newline,
    /// because a <c>\n</c> may follow in the next read.
    /// </param>
    /// <param name="firstSize">Length of the newline that ends the last line of the event; 0 for an empty event.</param>
    /// <param name="secondSize">Length of the newline that forms the blank line after the event.</param>
    /// <returns>Index of the first of the two newlines, or -1 when no complete event is available yet.</returns>
    // firstSize will be 0 for an empty event
    private static int IndexOfEndOfEvent(int inspectedLength, Span<byte> bytes, bool streamCompleted, out int firstSize, out int secondSize)
    {
        // the worst case of buffers split point is \r\n\r | \n
        var startIndex = Math.Max(0, inspectedLength - 3);
        while (true)
        {
            var index = IndexOfNewLine(startIndex, bytes, out firstSize);
            if (index == -1)
            {
                secondSize = 0;
                return -1;
            }

            if (index == 0)
            {
                // empty event, one newline is enough
                secondSize = firstSize;
                firstSize = 0;
                if (!streamCompleted && bytes.Length == 1 && bytes[index] == '\r')
                {
                    // newline is a single \r in the end of the buffer.
                    // it is possible that a \n will follow, so we don't flush the empty event yet
                    return -1;
                }

                return index;
            }

            if (index + firstSize >= bytes.Length)
            {
                // nothing after the first newline yet
                secondSize = 0;
                return -1;
            }

            if (IsNewline(index + firstSize, bytes, out secondSize))
            {
                if (!streamCompleted && secondSize == 1 && index + firstSize + secondSize == bytes.Length && bytes[index + firstSize] == '\r')
                {
                    // second newline is a single \r in the end of the buffer.
                    // it is possible that a \n will follow, so we don't flush the event yet
                    return -1;
                }

                return index;
            }

            // the byte right after the newline is not a newline, so it can be skipped as well
            startIndex = index + firstSize + 1;
        }
    }

    /// <summary>
    /// Growable byte buffer backed by an array rented from <see cref="ArrayPool{T}.Shared"/>.
    /// This is a mutable struct: it must not be copied while in use.
    /// </summary>
    private struct Buffer : IDisposable
    {
        private const int _initialSize = 4096;
        private const int _leastAdditionalSpace = 128;

        private int _length;
        private byte[] _array;

        public Buffer()
        {
            _array = ArrayPool<byte>.Shared.Rent(_initialSize);
        }

        /// <summary>Number of valid bytes in the buffer.</summary>
        public readonly int Length => _length;

        /// <summary>Valid bytes of the buffer as a <see cref="Memory{T}"/>.</summary>
        public readonly Memory<byte> Memory => _array.AsMemory(0, _length);

        /// <summary>Valid bytes of the buffer as a <see cref="Span{T}"/>.</summary>
        public readonly Span<byte> Span => _array.AsSpan(0, _length);

        // Replaces the backing array with a rented array of at least double the size.
        private void Grow()
        {
            var newArray = ArrayPool<byte>.Shared.Rent(_array.Length * 2);
            // only the first _length bytes carry data, the rest of the old array is garbage
            _array.AsSpan(0, _length).CopyTo(newArray);
            ArrayPool<byte>.Shared.Return(_array);
            _array = newArray;
        }

        /// <summary>
        /// Returns the free tail of the buffer, growing the buffer first
        /// so that the tail has at least <c>_leastAdditionalSpace</c> bytes.
        /// </summary>
        public Memory<byte> GetReadTargetMemory()
        {
            while (_length + _leastAdditionalSpace > _array.Length)
            {
                Grow();
            }

            return _array.AsMemory(_length);
        }

        /// <summary>Marks <paramref name="count"/> bytes written into the free tail as valid.</summary>
        public void Advance(int count)
        {
            System.Diagnostics.Debug.Assert(_length + count <= _array.Length);
            _length += count;
        }

        /// <summary>Removes the first <paramref name="count"/> bytes and moves the rest to the start.</summary>
        public void ShiftLeft(int count)
        {
            System.Diagnostics.Debug.Assert(count <= _length);
            if (count == 0)
            {
                return;
            }

            var keepSpan = _array.AsSpan(count, _length - count);
            keepSpan.CopyTo(_array.AsSpan());
            _length -= count;
        }

        /// <summary>Returns the backing array to the pool.</summary>
        public void Dispose()
        {
            ArrayPool<byte>.Shared.Return(_array);
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Text; | |
using System.Text.Json; | |
using System.Text.Json.Nodes; | |
// Sample: streams a chat completion from the OpenAI API and prints the generated text
// as it arrives. The API key is taken from the OPENAI_API_KEY environment variable.
using var httpClient = new HttpClient()
{
    BaseAddress = new Uri("https://api.openai.com/v1/"),
};
httpClient.DefaultRequestHeaders.Authorization = new("Bearer", Environment.GetEnvironmentVariable("OPENAI_API_KEY"));

using var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions")
{
    Content = new StringContent(
        """
        {
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Write all steps to implement server sent events parsing from stream in C#"
                }
            ],
            "stream": true,
            "seed": 43,
            "max_tokens": 2000
        }
        """,
        Encoding.UTF8,
        "application/json"
    )
};

// ResponseHeadersRead: do not buffer the body, events have to be processed while they arrive
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
Console.WriteLine("StatusCode: " + response.StatusCode);
Console.WriteLine("ContentType: " + response.Content.Headers.ContentType);
Console.WriteLine();

if (response.StatusCode != System.Net.HttpStatusCode.OK)
{
    Console.WriteLine(await response.Content.ReadAsStringAsync());
    return;
}

// asynchronous variant: the synchronous ReadAsStream would block the thread inside async code
await using var stream = await response.Content.ReadAsStreamAsync();

var events = SseParser.Parse(stream, (eventType, data) =>
{
    // the literal [DONE] payload marks the end of the stream and is mapped to null
    if (data.SequenceEqual("[DONE]"u8))
    {
        return null;
    }
    else
    {
        return JsonSerializer.Deserialize<JsonObject>(data);
    }
});

await foreach (var @event in events)
{
    if (@event.Data is null)
    {
        Console.WriteLine();
        Console.WriteLine("--- completed ---");
        break;
    }

    Console.Write(@event.Data["choices"]![0]!["delta"]!["content"]);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Text; | |
/// <summary>
/// Runs <see cref="SseParser"/> against the example streams of the WHATWG Server-Sent Events spec
/// with every supported line ending.
/// </summary>
public class SseParserTest
{
    [Theory]
    [InlineData("\n")]
    [InlineData("\r")]
    [InlineData("\r\n")]
    public async Task TestExamplesFromSpec(string lineEndings)
    {
        // NOTE: blank lines inside the literals are significant, an event is dispatched
        // only when it is followed by an empty line. A literal that ends with two blank
        // lines produces the "\n\n" that terminates the last event.
        await Test(
            """
            data: This is the first message.

            data: This is the second message, it
            data: has two lines.

            data: This is the third message.


            """.ReplaceLineEndings(lineEndings),
            (_, data) => Encoding.UTF8.GetString(data),
            [
                MakeItem("This is the first message."),
                MakeItem("This is the second message, it\nhas two lines."),
                MakeItem("This is the third message."),
            ]
        );

        await Test(
            """
            event: add
            data: 73857293

            event: remove
            data: 2153

            event: add
            data: 113411


            """.ReplaceLineEndings(lineEndings),
            (_, data) => int.Parse(data),
            [
                MakeItem(73857293, eventType: "add"),
                MakeItem(2153, eventType: "remove"),
                MakeItem(113411, eventType: "add"),
            ]
        );

        await Test(
            """
            data: YHOO
            data: +2
            data: 10


            """.ReplaceLineEndings(lineEndings),
            (_, data) => Encoding.UTF8.GetString(data),
            [
                MakeItem("YHOO\n+2\n10"),
            ]
        );

        // a comment-only block, an id that is set and then reset by an empty "id" field,
        // and a value with two spaces after the colon (only the first space is stripped)
        await Test(
            """
            : test stream

            data: first event
            id: 1

            data:second event
            id

            data:  third event


            """.ReplaceLineEndings(lineEndings),
            (_, data) => Encoding.UTF8.GetString(data),
            [
                MakeItem("first event", lastEventId: "1"),
                MakeItem("second event"),
                MakeItem(" third event"),
            ]
        );

        // the last "data:" is not followed by a blank line, so it is never dispatched
        await Test(
            """
            data

            data
            data

            data:

            data:
            """.ReplaceLineEndings(lineEndings),
            (_, data) => Encoding.UTF8.GetString(data),
            [
                MakeItem(""),
                MakeItem("\n"),
                MakeItem(""),
            ]
        );

        await Test(
            """
            data:test

            data: test


            """.ReplaceLineEndings(lineEndings),
            (_, data) => Encoding.UTF8.GetString(data),
            [
                MakeItem("test"),
                MakeItem("test"),
            ]
        );
    }

    // Builds an expected item; omitted arguments take the values the parser reports
    // when the corresponding field never appeared in the stream.
    private static SseItem<T> MakeItem<T>(T data, string? eventType = null, string? lastEventId = null, TimeSpan? reconnectionInterval = null)
    {
        return new(eventType ?? SseParser.DefaultEventType, data, lastEventId ?? string.Empty, reconnectionInterval ?? Timeout.InfiniteTimeSpan);
    }

    // Writes the input into an in-memory stream, parses it and compares the result with the expectation.
    private static async Task Test<T>(string input, EventDataParser<T> parser, SseItem<T>[] expectedItems)
    {
        using var stream = new MemoryStream();

        // Encoding.UTF8 makes the writer emit a byte order mark, so BOM skipping is exercised as well
        using (var writer = new StreamWriter(stream, Encoding.UTF8, leaveOpen: true))
        {
            writer.Write(input);
            writer.Flush();
        }

        stream.Position = 0;

        var items = await SseParser.Parse(stream, parser).ToArrayAsync();

        Assert.Equal(expectedItems.AsEnumerable(), items);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Buffers; | |
using System.Reflection; | |
using System.Text; | |
using System.Text.Json; | |
using Microsoft.AspNetCore.Builder; | |
using Microsoft.AspNetCore.Http; | |
using Microsoft.AspNetCore.Http.Json; | |
using Microsoft.AspNetCore.Http.Metadata; | |
using Microsoft.Extensions.DependencyInjection; | |
using Microsoft.Extensions.Options; | |
/// <summary>
/// An <see cref="IResult"/> that writes an asynchronous sequence of values to the response
/// as Server-Sent Events, each value serialized as JSON into the <c>data</c> field,
/// optionally followed by a final plain-text completion event.
/// </summary>
/// <typeparam name="TValue">Type of the values serialized into the events.</typeparam>
public class SseResult<TValue> : IResult, IEndpointMetadataProvider, IStatusCodeHttpResult
{
    private const string _contentType = "text/event-stream";

    private static readonly SearchValues<char> _newlines = SearchValues.Create("\r\n");
    private static readonly JsonSerializerOptions _defaultSerializerOptions = new JsonOptions().SerializerOptions;

    private readonly IAsyncEnumerable<(string? EventType, TValue Value)> _values;
    private readonly (string? EventType, string Data)? _completionEvent;

    /// <summary>
    /// Creates the result.
    /// </summary>
    /// <param name="values">Events to send; a null or empty event type omits the <c>event</c> field.</param>
    /// <param name="completionEvent">Optional event sent after <paramref name="values"/> is exhausted; its data is written as is.</param>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
    /// <exception cref="ArgumentException">The completion event type contains a newline.</exception>
    /// <exception cref="NotSupportedException">The completion event data contains a newline.</exception>
    public SseResult(
        IAsyncEnumerable<(string? EventType, TValue Value)> values,
        (string? EventType, string Data)? completionEvent
    )
    {
        _values = values ?? throw new ArgumentNullException(nameof(values));

        if (completionEvent.HasValue)
        {
            if (HasNewlines(completionEvent.Value.EventType))
            {
                throw new ArgumentException("Event type shall not contain newline chars.", nameof(completionEvent));
            }

            if (HasNewlines(completionEvent.Value.Data))
            {
                // we may support it later
                throw new NotSupportedException("Newline chars are not supported in the completion event data.");
            }
        }

        _completionEvent = completionEvent;
    }

    /// <inheritdoc/>
    public int? StatusCode => StatusCodes.Status200OK;

    /// <summary>
    /// Streams the events to the response body, flushing after every event.
    /// </summary>
    /// <param name="httpContext">Context of the current request.</param>
    /// <exception cref="InvalidOperationException">An event type produced by the sequence contains a newline.</exception>
    public async Task ExecuteAsync(HttpContext httpContext)
    {
        ArgumentNullException.ThrowIfNull(httpContext);

        var response = httpContext.Response;
        response.Headers.ContentType = _contentType;

        var jsonOptions = ResolveSerializerOptions(httpContext);
        var responseStream = response.Body;
        // reused for the small textual parts of every frame
        var buffer = new ArrayBufferWriter<byte>(128);
        var cancellationToken = httpContext.RequestAborted;

        await foreach (var (eventType, value) in _values.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            // validate before anything of this event is written
            if (HasNewlines(eventType))
            {
                throw new InvalidOperationException("Event type shall not contain newline chars.");
            }

            WriteEventTypeLine(buffer, eventType);
            buffer.Write("data: "u8);
            await responseStream.WriteAsync(buffer.WrittenMemory, cancellationToken).ConfigureAwait(false);
            buffer.ResetWrittenCount();

            // the payload goes directly to the response stream, without intermediate buffering here
            await JsonSerializer.SerializeAsync(responseStream, value, jsonOptions, cancellationToken).ConfigureAwait(false);

            buffer.Write("\n\n"u8);
            await responseStream.WriteAsync(buffer.WrittenMemory, cancellationToken).ConfigureAwait(false);
            buffer.ResetWrittenCount();

            // push the event to the client right away
            await responseStream.FlushAsync(cancellationToken).ConfigureAwait(false);
        }

        if (_completionEvent.HasValue)
        {
            var (eventType, data) = _completionEvent.Value;

            WriteEventTypeLine(buffer, eventType);
            buffer.Write("data: "u8);
            Encoding.UTF8.GetBytes(data, buffer);
            buffer.Write("\n\n"u8);

            await responseStream.WriteAsync(buffer.WrittenMemory, cancellationToken).ConfigureAwait(false);
            await responseStream.FlushAsync(cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Adds metadata describing a 200 response of type <typeparamref name="TValue"/> with the event-stream content type.
    /// </summary>
    public static void PopulateMetadata(MethodInfo method, EndpointBuilder builder)
    {
        builder.Metadata.Add(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, typeof(TValue), [_contentType]));
    }

    private static bool HasNewlines(string? value)
    {
        return value is not null && value.AsSpan().IndexOfAny(_newlines) != -1;
    }

    // Appends the "event: <type>\n" line; writes nothing when the type is null or empty.
    private static void WriteEventTypeLine(ArrayBufferWriter<byte> buffer, string? eventType)
    {
        if (string.IsNullOrEmpty(eventType))
        {
            return;
        }

        buffer.Write("event: "u8);
        Encoding.UTF8.GetBytes(eventType, buffer);
        buffer.Write("\n"u8);
    }

    private static JsonSerializerOptions ResolveSerializerOptions(HttpContext httpContext)
    {
        // Attempt to resolve options from DI then fallback to default options
        var result = httpContext.RequestServices?.GetService<IOptions<JsonOptions>>()?.Value?.SerializerOptions ?? _defaultSerializerOptions;

        if (result.WriteIndented)
        {
            // SSE doesn't allow newlines without repeating field name
            // NOTE(review): this creates a new options instance on every call; consider caching the copy
            return new(result) { WriteIndented = false };
        }
        else
        {
            return result;
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment