Skip to content

Instantly share code, notes, and snippets.

@jnm2
Last active March 8, 2025 23:23
Show Gist options
  • Save jnm2/f5c05af3cbd1fb1def03dc51e6fcbf65 to your computer and use it in GitHub Desktop.
Save jnm2/f5c05af3cbd1fb1def03dc51e6fcbf65 to your computer and use it in GitHub Desktop.
using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Immutable;
using System.IO;
using System.Net;
using System.Net.NetworkInformation;
using System.Net.Sockets;
using System.Text;
using System.Threading.Channels;
/// <summary>
/// Discovers SQL Server instances on the local network asynchronously using <see
/// href="https://learn.microsoft.com/en-us/openspecs/windows_protocols/mc-sqlr"/>.
/// </summary>
/// <remarks>
/// Nothing completes the sequence on its own: enumeration keeps waiting for further responses until the
/// consumer stops enumerating or the cancellation token is signalled.
/// </remarks>
public sealed class SqlServerDiscoverer : IAsyncEnumerable<SqlServerInstance>
{
    /// <summary>
    /// Sends a broadcast request from each eligible local IPv4 interface and yields every instance reported
    /// in the responses as they arrive.
    /// </summary>
    /// <param name="cancellationToken">Cancels the send/receive loops and the enumeration itself.</param>
    /// <returns>An enumerator over the discovered instances.</returns>
    public async IAsyncEnumerator<SqlServerInstance> GetAsyncEnumerator(CancellationToken cancellationToken = default)
    {
        var clients = new List<SsrpClient>();
        try
        {
            var channel = Channel.CreateUnbounded<SqlServerInstance>();

            // If more than one interface can broadcast, you must bind to a particular interface before broadcasting or Windows will discard the packet.
            // FirstOrDefault rather than SingleOrDefault: an interface may have several IPv4 addresses, and
            // SingleOrDefault would throw InvalidOperationException for it and abort the whole discovery.
            foreach (var interfaceAddress in (
                from i in NetworkInterface.GetAllNetworkInterfaces()
                where i.OperationalStatus == OperationalStatus.Up && !i.IsReceiveOnly && i.SupportsMulticast
                let localV4Address = i.GetIPProperties().UnicastAddresses.Select(a => a.Address).FirstOrDefault(a => a.AddressFamily == AddressFamily.InterNetwork)
                where localV4Address != null
                select localV4Address
                ).DefaultIfEmpty(IPAddress.Any))
            {
                var client = new SsrpClient(new IPEndPoint(interfaceAddress, 0));
                clients.Add(client);
                client.InstancesDiscovered += (_, instances) =>
                {
                    foreach (var instance in instances)
                    {
                        channel.Writer.TryWrite(instance);
                    }
                };

                // Deliberately not awaited: each client keeps receiving in the background while results are
                // streamed to the consumer through the channel.
                _ = client.ResolveAsync(cancellationToken);
            }

            await foreach (var item in channel.Reader.ReadAllAsync(cancellationToken))
            {
                yield return item;
            }
        }
        finally
        {
            // Runs on cancellation and when the consumer disposes the enumerator early; disposing the
            // sockets ends the background receive loops.
            foreach (var client in clients)
            {
                client.Dispose();
            }
        }
    }

    /// <summary>
    /// Owns one UDP socket bound to a local endpoint, broadcasts a single request from it and raises
    /// <see cref="InstancesDiscovered"/> for every well-formed response received.
    /// </summary>
    private sealed class SsrpClient : IDisposable
    {
        // Message type bytes: the first byte of the request sent and of the responses accepted.
        private const byte CLNT_BCAST_EX = 0x02;
        private const byte SVR_RESP = 0x05;

        private readonly Socket socket;
        private bool isDisposed;

        /// <summary>Raised with the instances parsed from one response datagram.</summary>
        public event EventHandler<ImmutableArray<SqlServerInstance>>? InstancesDiscovered;

        /// <summary>Creates a broadcast-enabled UDP socket bound to <paramref name="localEndPoint"/>.</summary>
        public SsrpClient(IPEndPoint localEndPoint)
        {
            socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp) { EnableBroadcast = true };
            socket.Bind(localEndPoint);
        }

        /// <summary>Disposes the socket. Safe to call more than once.</summary>
        public void Dispose()
        {
            if (Interlocked.Exchange(ref isDisposed, true))
            {
                return;
            }

            socket.Dispose();
        }

        /// <summary>
        /// Broadcasts one request to UDP port 1434, then receives responses until the client is disposed
        /// or <paramref name="cancellationToken"/> is cancelled.
        /// </summary>
        /// <param name="cancellationToken">Cancels the send and the pending receive.</param>
        /// <returns>
        /// A task that completes normally on cancellation or disposal. Socket failures that occur while the
        /// client is still alive propagate through the task.
        /// </returns>
        public async Task ResolveAsync(CancellationToken cancellationToken)
        {
            var buffer = ArrayPool<byte>.Shared.Rent(4096);
            try
            {
                buffer[0] = CLNT_BCAST_EX;
                await socket.SendToAsync(buffer.AsMemory(0, 1), new IPEndPoint(IPAddress.Broadcast, 1434), cancellationToken);

                while (!Volatile.Read(ref isDisposed))
                {
                    var bytesReceived = await socket.ReceiveAsync(buffer, cancellationToken);
                    try
                    {
                        HandleResponse(buffer.AsSpan(0, bytesReceived));
                    }
                    catch (InvalidDataException)
                    {
                        // One malformed datagram must not stop this interface from hearing later,
                        // well-formed responses, so it is skipped.
                    }
                }
            }
            catch (OperationCanceledException)
            {
            }
            catch (Exception ex) when (Volatile.Read(ref isDisposed) && (ex is ObjectDisposedException or SocketException))
            {
                // Dispose() was called while a socket operation was pending. That is the normal shutdown
                // path when the consumer stops enumerating without cancelling, not an error.
            }
            finally
            {
                // Return the rented buffer on every exit path, including failures.
                ArrayPool<byte>.Shared.Return(buffer);
            }
        }

        /// <summary>
        /// Validates the 3-byte header (message type plus little-endian 16-bit payload length) and raises
        /// <see cref="InstancesDiscovered"/> with the parsed payload. Datagrams of any other type are ignored.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// The declared payload length exceeds the received data, or the payload is malformed.
        /// </exception>
        private void HandleResponse(Span<byte> data)
        {
            var handler = InstancesDiscovered;
            if (handler is null)
            {
                return;
            }

            if (data is not [SVR_RESP, _, _, ..])
            {
                return;
            }

            var responseLength = BinaryPrimitives.ReadUInt16LittleEndian(data[1..]);
            if (responseLength > data.Length - 3)
            {
                throw new InvalidDataException("Response did not fit in buffer.");
            }

            handler.Invoke(this, ParseRespData(data.Slice(3, responseLength)));
        }

        /// <summary>
        /// Parses a response payload into instances. For each instance the fields ServerName, InstanceName,
        /// IsClustered and Version are read in that order; any remaining fields are skipped.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// A field is missing, out of order or unterminated, or the version text cannot be parsed.
        /// </exception>
        private static ImmutableArray<SqlServerInstance> ParseRespData(ReadOnlySpan<byte> respData)
        {
            var instances = ImmutableArray.CreateBuilder<SqlServerInstance>();
            var reader = new RespDataReader(respData);
            while (!reader.IsEndOfData)
            {
                var serverName = Encoding.UTF8.GetString(reader.ReadField("ServerName"u8));
                var instanceName = Encoding.UTF8.GetString(reader.ReadField("InstanceName"u8));
                _ = reader.ReadField("IsClustered"u8);

                // TryParse keeps every parse failure on the single InvalidDataException path instead of
                // surfacing the assorted exception types Version.Parse can throw.
                var versionText = Encoding.UTF8.GetString(reader.ReadField("Version"u8));
                if (!Version.TryParse(versionText, out var version))
                {
                    throw new InvalidDataException($"Expected a version number but found: {versionText}");
                }

                instances.Add(new SqlServerInstance
                {
                    ServerName = serverName,
                    // The name MSSQLSERVER is reported as "no instance name".
                    InstanceName = instanceName.Equals("MSSQLSERVER", StringComparison.OrdinalIgnoreCase) ? null : instanceName,
                    Version = version,
                });
                reader.ReadToEndOfInstance();
            }

            return instances.DrainToImmutable();
        }

        /// <summary>
        /// Forward-only reader over a payload made of semicolon-terminated <c>name;value;</c> pairs.
        /// </summary>
        private ref struct RespDataReader(ReadOnlySpan<byte> respData)
        {
            private ReadOnlySpan<byte> remaining = respData;

            /// <summary>True once every byte of the payload has been consumed.</summary>
            public bool IsEndOfData => remaining.IsEmpty;

            /// <summary>
            /// Consumes <c>expectedName;value;</c> from the front of the remaining data and returns the value bytes.
            /// </summary>
            /// <exception cref="InvalidDataException">
            /// The remaining data does not start with <paramref name="expectedName"/> followed by a
            /// semicolon, or the value has no terminating semicolon.
            /// </exception>
            public ReadOnlySpan<byte> ReadField(ReadOnlySpan<byte> expectedName)
            {
                if (!remaining.StartsWith(expectedName) || !remaining[expectedName.Length..].StartsWith((byte)';'))
                {
                    throw new InvalidDataException($"Expected \"{Encoding.UTF8.GetString(expectedName)};\" but found: {Encoding.UTF8.GetString(remaining)}");
                }

                remaining = remaining[(expectedName.Length + 1)..];

                var fieldEndIndex = remaining.IndexOf((byte)';');
                if (fieldEndIndex == -1)
                {
                    throw new InvalidDataException($"Expected semicolon following \"{Encoding.UTF8.GetString(expectedName)};\" but found: {Encoding.UTF8.GetString(remaining)}");
                }

                var fieldValue = remaining[..fieldEndIndex];
                remaining = remaining[(fieldEndIndex + 1)..];
                return fieldValue;
            }

            /// <summary>
            /// Skips any further <c>name;value;</c> pairs and consumes the empty field name (a lone
            /// semicolon) that marks the end of the current instance.
            /// </summary>
            /// <exception cref="InvalidDataException">The data ends before the end-of-instance marker.</exception>
            public void ReadToEndOfInstance()
            {
                while (true)
                {
                    var nextSemicolonIndex = remaining.IndexOf((byte)';');
                    if (nextSemicolonIndex == -1)
                    {
                        throw new InvalidDataException($"Expected semicolon at end of instance, but found: {Encoding.UTF8.GetString(remaining)}");
                    }

                    remaining = remaining[(nextSemicolonIndex + 1)..];
                    if (nextSemicolonIndex == 0)
                    {
                        // The last thing that was read would have been a semicolon, and here we are again at a semicolon
                        // with no field name. This must be the end.
                        break;
                    }

                    // Otherwise, we read past a field name and its semicolon.
                    nextSemicolonIndex = remaining.IndexOf((byte)';');
                    if (nextSemicolonIndex == -1)
                    {
                        throw new InvalidDataException($"Expected field value and semicolon following field name and semicolon, but found: {Encoding.UTF8.GetString(remaining)}");
                    }

                    remaining = remaining[(nextSemicolonIndex + 1)..];
                }
            }
        }
    }
}
/// <summary>
/// Describes one SQL Server instance: the server it runs on, its instance name (if any) and its version.
/// </summary>
public sealed class SqlServerInstance
{
    /// <summary>Name of the server hosting the instance.</summary>
    public required string ServerName { get; init; }

    /// <summary>Name of the instance, or <see langword="null"/> when there is no instance name to append.</summary>
    public required string? InstanceName { get; init; }

    /// <summary>Version of the instance.</summary>
    public required Version Version { get; init; }

    /// <summary>
    /// <see cref="ServerName"/> on its own when <see cref="InstanceName"/> is <see langword="null"/>;
    /// otherwise <c>ServerName\InstanceName</c>.
    /// </summary>
    public string DataSourceName
    {
        get
        {
            if (InstanceName is { } namedInstance)
            {
                return ServerName + @"\" + namedInstance;
            }

            return ServerName;
        }
    }

    /// <summary>Returns <see cref="DataSourceName"/> followed by the version in parentheses.</summary>
    public override string ToString()
    {
        return $"{DataSourceName} ({Version})";
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment