Skip to content

Instantly share code, notes, and snippets.

@MihaZupan
Last active June 25, 2020 20:05
Show Gist options
  • Save MihaZupan/882eac6ec582ef07dc74fba014011102 to your computer and use it in GitHub Desktop.
Save MihaZupan/882eac6ec582ef07dc74fba014011102 to your computer and use it in GitHub Desktop.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using System.Diagnostics;
using Microsoft.Extensions.Internal;
namespace System.Net
{
internal static partial class NameResolutionPal
{
// Set (last, under s_initializedLock) once WSAStartup has succeeded and s_getAddrInfoExSupported
// has been computed; volatile so the lock-free check in EnsureSocketsAreInitialized can read it.
private static volatile bool s_initialized;
// Serializes the one-time initialization performed by EnsureSocketsAreInitialized.
private static readonly object s_initializedLock = new object();
// Delegate passed to GetAddrInfoExW as the completion routine. Held in a static field, presumably so
// the delegate stays alive while native code may still invoke it — confirm against interop conventions.
private static readonly unsafe Interop.Winsock.LPLOOKUPSERVICE_COMPLETION_ROUTINE s_getAddrInfoExCallback = GetAddressInfoExCallback;
// Result of GetAddrInfoExSupportsOverlapped(); only meaningful after s_initialized is true.
private static bool s_getAddrInfoExSupported;
/// <summary>
/// Performs the one-time Winsock startup (WSAStartup) and records whether GetAddrInfoExW
/// supports overlapped calls. Cheap after the first successful call.
/// </summary>
/// <exception cref="SocketException">WSAStartup reported an error.</exception>
public static void EnsureSocketsAreInitialized()
{
    // Fast path: the volatile flag is already set, so no lock is needed.
    if (s_initialized)
    {
        return;
    }

    InitializeCore();

    // Slow path: double-checked under the lock so only one thread performs startup.
    static void InitializeCore()
    {
        lock (s_initializedLock)
        {
            // Another thread may have completed initialization while this one waited.
            if (s_initialized)
            {
                return;
            }

            SocketError startupError = Interop.Winsock.WSAStartup();
            if (startupError != SocketError.Success)
            {
                // WSAStartup does not set LastWin32Error, so the returned code is used directly.
                throw new SocketException((int)startupError);
            }

            s_getAddrInfoExSupported = GetAddrInfoExSupportsOverlapped();

            // Written last: a thread observing the volatile flag as true also sees the support flag.
            s_initialized = true;
        }
    }
}
/// <summary>
/// Gets whether asynchronous resolution via overlapped GetAddrInfoExW is available.
/// Reading this property runs socket initialization first, because that is where support is probed.
/// </summary>
public static bool SupportsGetAddrInfoAsync
{
    get { EnsureSocketsAreInitialized(); return s_getAddrInfoExSupported; }
}
/// <summary>
/// Synchronously resolves <paramref name="name"/> with GetAddrInfoW across all address families.
/// </summary>
/// <param name="name">Host name to resolve.</param>
/// <param name="justAddresses">When true the canonical name is not requested and <paramref name="hostName"/> is null on success.</param>
/// <param name="hostName">On success, the canonical name (or null); on failure, <paramref name="name"/>.</param>
/// <param name="aliases">Always empty in this implementation.</param>
/// <param name="addresses">The resolved IPv4/IPv6 addresses, or empty on failure.</param>
/// <param name="nativeErrorCode">0 on success, otherwise the error returned by GetAddrInfoW.</param>
/// <returns><see cref="SocketError.Success"/> or the GetAddrInfoW error code.</returns>
public static unsafe SocketError TryGetAddrInfo(string name, bool justAddresses, out string? hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode)
{
    aliases = Array.Empty<string>();

    // Unspecified family asks for every address family; the canonical name costs extra and is
    // only requested when the caller wants a full host entry.
    Interop.Winsock.AddressInfo hints = default;
    hints.ai_family = AddressFamily.Unspecified;
    if (!justAddresses)
    {
        hints.ai_flags = AddressInfoHints.AI_CANONNAME;
    }

    Interop.Winsock.AddressInfo* addressInfoList = null;
    try
    {
        var status = (SocketError)Interop.Winsock.GetAddrInfoW(name, null, &hints, &addressInfoList);
        if (status == SocketError.Success)
        {
            addresses = ParseAddressInfo(addressInfoList, justAddresses, out hostName);
            nativeErrorCode = 0;
            return SocketError.Success;
        }

        // Failure: echo the requested name back and report no addresses.
        hostName = name;
        addresses = Array.Empty<IPAddress>();
        nativeErrorCode = (int)status;
        return status;
    }
    finally
    {
        // The native list must be released whether or not parsing succeeded.
        if (addressInfoList != null)
        {
            Interop.Winsock.FreeAddrInfoW(addressInfoList);
        }
    }
}
/// <summary>
/// Performs a reverse lookup of <paramref name="addr"/> with GetNameInfoW, requiring a name (NI_NAMEREQD).
/// </summary>
/// <param name="addr">Address to look up.</param>
/// <param name="errorCode">The result reported by GetNameInfoW.</param>
/// <param name="nativeErrorCode">0 on success, otherwise <paramref name="errorCode"/> as an int.</param>
/// <returns>The host name on success; otherwise null.</returns>
public static unsafe string? TryGetNameInfo(IPAddress addr, out SocketError errorCode, out int nativeErrorCode)
{
    const int StackAddressBufferSize = 64;
    const int NI_MAXHOST = 1025;

    SocketAddress socketAddress = new IPEndPoint(addr, 0).Serialize();
    int addressSize = socketAddress.Size;

    // Small serialized addresses use a stack buffer; anything larger falls back to the heap.
    Span<byte> addressBytes = addressSize <= StackAddressBufferSize ? stackalloc byte[StackAddressBufferSize] : new byte[addressSize];
    for (int index = 0; index < addressSize; index++)
    {
        addressBytes[index] = socketAddress[index];
    }

    char* hostNameBuffer = stackalloc char[NI_MAXHOST];
    fixed (byte* addressPtr = addressBytes)
    {
        errorCode = Interop.Winsock.GetNameInfoW(
            addressPtr,
            addressSize,
            hostNameBuffer,
            NI_MAXHOST,
            null, // no service name wanted,
            0,    // hence no service buffer or length
            (int)Interop.Winsock.NameInfoFlags.NI_NAMEREQD);
    }

    if (errorCode != SocketError.Success)
    {
        nativeErrorCode = (int)errorCode;
        return null;
    }

    nativeErrorCode = 0;
    return new string(hostNameBuffer);
}
/// <summary>
/// Returns the local machine's host name as reported by Winsock gethostname.
/// </summary>
/// <exception cref="SocketException">gethostname failed.</exception>
public static unsafe string GetHostName()
{
    // Deliberately not cached: the host name may change while the process is running.
    const int HostNameBufferLength = 256;
    byte* nameBuffer = stackalloc byte[HostNameBufferLength];

    SocketError status = Interop.Winsock.gethostname(nameBuffer, HostNameBufferLength);
    if (status == SocketError.Success)
    {
        return new string((sbyte*)nameBuffer);
    }

    if (NetEventSource.IsEnabled)
    {
        NetEventSource.Error(null, $"GetHostName failed with {status}");
    }

    // The parameterless constructor picks up the last OS socket error, i.e. the one from gethostname.
    throw new SocketException();
}
/// <summary>
/// Starts an asynchronous GetAddrInfoExW lookup for <paramref name="hostName"/>.
/// </summary>
/// <param name="hostName">Host name to resolve.</param>
/// <param name="justAddresses">True to produce a Task of IPAddress[]; false for a Task of IPHostEntry.</param>
/// <returns>A task completed (on the thread pool) with the lookup outcome.</returns>
public static unsafe Task GetAddrInfoAsync(string hostName, bool justAddresses)
{
    ValueStopwatch stopwatch = NameResolutionTelemetry.Log.ResolutionStart(hostName);
    GetAddrInfoExContext* context = GetAddrInfoExContext.AllocateContext();

    GetAddrInfoExState state;
    try
    {
        state = new GetAddrInfoExState(hostName, justAddresses, stopwatch);
        context->QueryStateHandle = state.CreateHandle();
    }
    catch
    {
        // Nothing has been handed to native code yet: release the context and close out telemetry here.
        GetAddrInfoExContext.FreeContext(context);
        NameResolutionTelemetry.Log.AfterResolution(hostName, stopwatch, successful: false);
        throw;
    }

    // Unspecified family requests all address families; canonical name only for full host entries.
    Interop.Winsock.AddressInfoEx hints = default;
    hints.ai_family = AddressFamily.Unspecified;
    if (!justAddresses)
    {
        hints.ai_flags = AddressInfoHints.AI_CANONNAME;
    }

    SocketError startStatus = (SocketError)Interop.Winsock.GetAddrInfoExW(
        hostName, null, Interop.Winsock.NS_ALL, IntPtr.Zero, &hints, &context->Result, IntPtr.Zero, &context->Overlapped, s_getAddrInfoExCallback, &context->CancelHandle);

    // Any status other than IOPending is treated as an immediate completion and processed inline;
    // this code assumes the completion callback is not invoked in that case.
    // The context must not be touched afterwards: ProcessResult (inline or via callback) frees it.
    if (startStatus != SocketError.IOPending)
    {
        ProcessResult(startStatus, context);
    }

    return state.Task;
}
/// <summary>
/// Completion routine registered with GetAddrInfoExW; forwards the outcome to <c>ProcessResult</c>.
/// </summary>
/// <param name="error">Completion status, interpreted as a <see cref="SocketError"/>.</param>
/// <param name="bytes">Unused.</param>
/// <param name="overlapped">The <c>Overlapped</c> field of the request's GetAddrInfoExContext.</param>
private static unsafe void GetAddressInfoExCallback(int error, int bytes, NativeOverlapped* overlapped)
{
    // NativeOverlapped is the first field of the sequential-layout GetAddrInfoExContext, so the
    // overlapped pointer is also the address of the owning context.
    ProcessResult((SocketError)error, (GetAddrInfoExContext*)overlapped);
}
/// <summary>
/// Completes an asynchronous lookup: converts the native results (or error) into the task result,
/// reports telemetry, and always frees the native context.
/// </summary>
/// <param name="errorCode">Status of the GetAddrInfoExW operation.</param>
/// <param name="context">The request context; freed before this method returns.</param>
private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExContext* context)
{
    try
    {
        // Recovers the managed state and releases the GCHandle that kept it alive across the native call.
        GetAddrInfoExState state = GetAddrInfoExState.FromHandleAndFree(context->QueryStateHandle);
        try
        {
            bool successful = errorCode == SocketError.Success;
            if (successful)
            {
                IPAddress[] addresses = ParseAddressInfoEx(context->Result, state.JustAddresses, out string? hostName);
                state.SetResult(state.JustAddresses ? (object)
                    addresses :
                    new IPHostEntry
                    {
                        // Fall back to the requested name when no canonical name was returned.
                        HostName = hostName ?? state.HostName,
                        Aliases = Array.Empty<string>(),
                        AddressList = addresses
                    });
            }
            else
            {
                state.SetResult(ExceptionDispatchInfo.SetCurrentStackTrace(new SocketException((int)errorCode)));
            }

            // Report the actual outcome; a failed lookup must not be recorded as successful.
            NameResolutionTelemetry.Log.AfterResolution(state.HostName, state.Stopwatch, successful: successful);
        }
        catch when (Dns.LogFailure(state.HostName, state.Stopwatch))
        {
            // The filter only logs; it is expected to return false so the exception keeps propagating.
            // NOTE(review): if parsing throws, SetResult is never reached and the task does not complete.
            Debug.Fail("LogFailure should return false");
            throw;
        }
    }
    finally
    {
        GetAddrInfoExContext.FreeContext(context);
    }
}
/// <summary>
/// Converts a native GetAddrInfoW result list into managed addresses.
/// </summary>
/// <param name="addressInfoPtr">Head of the native list; must not be null.</param>
/// <param name="justAddresses">When true the canonical name is ignored and <paramref name="hostName"/> is null.</param>
/// <param name="hostName">The first canonical name found in the list, or null.</param>
/// <returns>IPv4 entries, plus IPv6 entries when the OS supports IPv6, in list order.</returns>
private static unsafe IPAddress[] ParseAddressInfo(Interop.Winsock.AddressInfo* addressInfoPtr, bool justAddresses, out string? hostName)
{
    Debug.Assert(addressInfoPtr != null);

    // Pass 1: count convertible entries so the array is allocated at its exact size.
    int count = 0;
    for (Interop.Winsock.AddressInfo* entry = addressInfoPtr; entry != null; entry = entry->ai_next)
    {
        if (IsConvertible(entry->ai_family, (int)entry->ai_addrlen))
        {
            count++;
        }
    }

    // Pass 2: convert the entries and capture the first canonical name (only when it was requested).
    var addresses = new IPAddress[count];
    int index = 0;
    string? canonicalName = null;
    for (Interop.Winsock.AddressInfo* entry = addressInfoPtr; entry != null; entry = entry->ai_next)
    {
        if (!justAddresses && canonicalName == null && entry->ai_canonname != null)
        {
            canonicalName = Marshal.PtrToStringUni((IntPtr)entry->ai_canonname);
        }

        int length = (int)entry->ai_addrlen;
        var socketAddress = new ReadOnlySpan<byte>(entry->ai_addr, length);
        if (!IsConvertible(entry->ai_family, length))
        {
            continue;
        }

        addresses[index++] = entry->ai_family == AddressFamily.InterNetwork
            ? CreateIPv4Address(socketAddress)
            : CreateIPv6Address(socketAddress);
    }

    // canonicalName is never assigned when justAddresses is true, so null is reported in that case.
    hostName = canonicalName;
    return addresses;

    // IPv4 entries of the expected size are kept; IPv6 entries only when the OS supports IPv6.
    static bool IsConvertible(AddressFamily family, int length) =>
        family == AddressFamily.InterNetwork
            ? length == SocketAddressPal.IPv4AddressSize
            : SocketProtocolSupportPal.OSSupportsIPv6 &&
              family == AddressFamily.InterNetworkV6 &&
              length == SocketAddressPal.IPv6AddressSize;
}
/// <summary>
/// Converts a native GetAddrInfoExW result list into managed addresses.
/// </summary>
/// <param name="addressInfoExPtr">Head of the native list; must not be null.</param>
/// <param name="justAddresses">When true the canonical name is ignored and <paramref name="hostName"/> is null.</param>
/// <param name="hostName">The first canonical name found in the list, or null.</param>
/// <returns>IPv4 entries, plus IPv6 entries when the OS supports IPv6, in list order.</returns>
private static unsafe IPAddress[] ParseAddressInfoEx(Interop.Winsock.AddressInfoEx* addressInfoExPtr, bool justAddresses, out string? hostName)
{
    Debug.Assert(addressInfoExPtr != null);

    // Pass 1: count convertible entries so the array is allocated at its exact size.
    int count = 0;
    for (Interop.Winsock.AddressInfoEx* entry = addressInfoExPtr; entry != null; entry = entry->ai_next)
    {
        if (IsConvertible(entry->ai_family, (int)entry->ai_addrlen))
        {
            count++;
        }
    }

    // Pass 2: convert the entries and capture the first canonical name (only when it was requested).
    var addresses = new IPAddress[count];
    int index = 0;
    string? canonicalName = null;
    for (Interop.Winsock.AddressInfoEx* entry = addressInfoExPtr; entry != null; entry = entry->ai_next)
    {
        if (!justAddresses && canonicalName == null && entry->ai_canonname != IntPtr.Zero)
        {
            canonicalName = Marshal.PtrToStringUni(entry->ai_canonname);
        }

        int length = (int)entry->ai_addrlen;
        var socketAddress = new ReadOnlySpan<byte>(entry->ai_addr, length);
        if (!IsConvertible(entry->ai_family, length))
        {
            continue;
        }

        addresses[index++] = entry->ai_family == AddressFamily.InterNetwork
            ? CreateIPv4Address(socketAddress)
            : CreateIPv6Address(socketAddress);
    }

    // canonicalName is never assigned when justAddresses is true, so null is reported in that case.
    hostName = canonicalName;
    return addresses;

    // IPv4 entries of the expected size are kept; IPv6 entries only when the OS supports IPv6.
    static bool IsConvertible(AddressFamily family, int length) =>
        family == AddressFamily.InterNetwork
            ? length == SocketAddressPal.IPv4AddressSize
            : SocketProtocolSupportPal.OSSupportsIPv6 &&
              family == AddressFamily.InterNetworkV6 &&
              length == SocketAddressPal.IPv6AddressSize;
}
/// <summary>Builds an IPv4 <see cref="IPAddress"/> from a serialized IPv4 socket address.</summary>
private static unsafe IPAddress CreateIPv4Address(ReadOnlySpan<byte> socketAddress)
{
    // Mask to the low 32 bits so the value passed to IPAddress(long) is never sign-extended.
    long rawAddress = (long)SocketAddressPal.GetIPv4Address(socketAddress);
    return new IPAddress(rawAddress & 0xFFFFFFFFL);
}
/// <summary>Builds an IPv6 <see cref="IPAddress"/> (including its scope id) from a serialized IPv6 socket address.</summary>
private static unsafe IPAddress CreateIPv6Address(ReadOnlySpan<byte> socketAddress)
{
    Span<byte> addressBytes = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes];
    SocketAddressPal.GetIPv6Address(socketAddress, addressBytes, out uint scopeId);

    IPAddress ipv6 = new IPAddress(addressBytes, scopeId);
    return ipv6;
}
/// <summary>
/// Managed state for one asynchronous GetAddrInfoExW request: the request inputs, the telemetry
/// stopwatch, and the builder that produces the strongly-typed result task.
/// </summary>
private sealed class GetAddrInfoExState : IThreadPoolWorkItem
{
    // Exactly one of these builders is used, selected by JustAddresses. They are mutable structs,
    // so they stay non-readonly fields and are always accessed in place.
    private AsyncTaskMethodBuilder<IPHostEntry> _ipHostEntryBuilder;
    private AsyncTaskMethodBuilder<IPAddress[]> _ipAddressArrayBuilder;

    // Outcome stored by SetResult and applied to the builder in Execute on the thread pool.
    private object? _result;

    public GetAddrInfoExState(string hostName, bool justAddresses, ValueStopwatch stopwatch)
    {
        HostName = hostName;
        JustAddresses = justAddresses;
        Stopwatch = stopwatch;

        // Reading Task makes the builder create its task now instead of lazily later.
        if (!justAddresses)
        {
            _ipHostEntryBuilder = AsyncTaskMethodBuilder<IPHostEntry>.Create();
            _ = _ipHostEntryBuilder.Task;
        }
        else
        {
            _ipAddressArrayBuilder = AsyncTaskMethodBuilder<IPAddress[]>.Create();
            _ = _ipAddressArrayBuilder.Task;
        }
    }

    public string HostName { get; }

    public bool JustAddresses { get; }

    public ValueStopwatch Stopwatch { get; }

    /// <summary>The Task of IPAddress[] or Task of IPHostEntry, depending on <see cref="JustAddresses"/>.</summary>
    public Task Task => JustAddresses ? _ipAddressArrayBuilder.Task : (Task)_ipHostEntryBuilder.Task;

    /// <summary>
    /// Records the outcome (an Exception, IPAddress[] or IPHostEntry) and schedules task completion on the thread pool.
    /// </summary>
    public void SetResult(object result)
    {
        // Completion is deferred to the thread pool so continuations never run on the Windows callback
        // thread — a hand-rolled RunContinuationsAsynchronously, since AsyncTaskMethodBuilder is used to
        // get a strongly-typed Task<IPHostEntry> or Task<IPAddress[]> without extra allocations.
        Debug.Assert(result is Exception || result is IPAddress[] || result is IPHostEntry);
        _result = result;
        ThreadPool.UnsafeQueueUserWorkItem(this, preferLocal: false);
    }

    void IThreadPoolWorkItem.Execute()
    {
        object? outcome = _result;
        Exception? failure = outcome as Exception;

        if (!JustAddresses)
        {
            if (failure is null)
            {
                _ipHostEntryBuilder.SetResult((IPHostEntry)outcome!);
            }
            else
            {
                _ipHostEntryBuilder.SetException(failure);
            }

            return;
        }

        if (failure is null)
        {
            _ipAddressArrayBuilder.SetResult((IPAddress[])outcome!);
        }
        else
        {
            _ipAddressArrayBuilder.SetException(failure);
        }
    }

    /// <summary>Pins this state behind a normal GCHandle and returns it as an IntPtr for native round-tripping.</summary>
    public IntPtr CreateHandle()
    {
        GCHandle handle = GCHandle.Alloc(this, GCHandleType.Normal);
        return GCHandle.ToIntPtr(handle);
    }

    /// <summary>Recovers the state from a handle produced by <see cref="CreateHandle"/> and frees that handle.</summary>
    public static GetAddrInfoExState FromHandleAndFree(IntPtr handle)
    {
        GCHandle gcHandle = GCHandle.FromIntPtr(handle);
        GetAddrInfoExState target = (GetAddrInfoExState)gcHandle.Target!;
        gcHandle.Free();
        return target;
    }
}
// Unmanaged per-request context for one GetAddrInfoExW call. It is allocated with AllocHGlobal in
// AllocateContext and released with FreeContext.
// Field order matters: Overlapped must remain the first field, because GetAddressInfoExCallback
// casts the NativeOverlapped* it receives directly to GetAddrInfoExContext*.
[StructLayout(LayoutKind.Sequential)]
private unsafe struct GetAddrInfoExContext
{
// Passed to GetAddrInfoExW as the overlapped structure for the request.
public NativeOverlapped Overlapped;
// Receives the native result list from GetAddrInfoExW; freed by FreeContext when non-null.
public Interop.Winsock.AddressInfoEx* Result;
// Passed to GetAddrInfoExW as its cancel-handle out parameter; not otherwise read in this file's
// visible code (presumably intended for cancellation — confirm).
public IntPtr CancelHandle;
// IntPtr form of the GCHandle to the request's GetAddrInfoExState; FreeContext does NOT release it
// (ProcessResult does, via GetAddrInfoExState.FromHandleAndFree).
public IntPtr QueryStateHandle;
// Allocates a zero-initialized context so Result starts out null.
public static GetAddrInfoExContext* AllocateContext()
{
var context = (GetAddrInfoExContext*)Marshal.AllocHGlobal(sizeof(GetAddrInfoExContext));
*context = default;
return context;
}
// Frees the native result list (if any) and then the context memory itself.
public static void FreeContext(GetAddrInfoExContext* context)
{
if (context->Result != null)
{
Interop.Winsock.FreeAddrInfoExW(context->Result);
}
Marshal.FreeHGlobal((IntPtr)context);
}
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment