Skip to content

Instantly share code, notes, and snippets.

@schmitch
Created January 31, 2025 11:56
Show Gist options
  • Save schmitch/3dc25d2de0a1c88c2b1452bc8173cbc8 to your computer and use it in GitHub Desktop.
Save schmitch/3dc25d2de0a1c88c2b1452bc8173cbc8 to your computer and use it in GitHub Desktop.
dotnet core requirelocalport
using EnvisiaCms.Simple.Shared;
using Microsoft.AspNetCore.Routing.Matching;
namespace EnvisiaCms.Simple;
/// <summary>
/// Routing matcher policy that narrows endpoint candidates by the local port a request arrived on
/// (<c>HttpContext.Connection.LocalPort</c>), driven by <see cref="LocalPortMetadata"/>.
/// </summary>
/// <remarks>
/// Endpoints carrying <see cref="LocalPortMetadata"/> only match requests whose local port equals
/// <see cref="LocalPortMetadata.Port"/>. Endpoints without the metadata match requests on any port,
/// including ports that no other endpoint at the same node mentions.
/// </remarks>
internal sealed class LocalPortMatcherPolicy : MatcherPolicy,
    IEndpointComparerPolicy,
    INodeBuilderPolicy
{
    /// <summary>
    /// Edge state for endpoints that do not restrict the local port. A private object (rather than a
    /// sentinel port number) guarantees it can never collide with a real port value.
    /// </summary>
    private static readonly object AnyPortState = new();

    /// <inheritdoc />
    public override int Order { get; } = -100;

    /// <summary>
    /// Returns <see langword="true"/> when at least one endpoint carries <see cref="LocalPortMetadata"/>.
    /// </summary>
    /// <param name="endpoints">The candidate endpoints.</param>
    public bool AppliesToEndpoints(IReadOnlyList<Endpoint> endpoints)
    {
        return AppliesToEndpointsCore(endpoints);
    }

    /// <summary>
    /// Splits <paramref name="endpoints"/> into one edge per required local port, plus an
    /// "any port" edge for endpoints without a port requirement (only if such endpoints exist).
    /// </summary>
    /// <param name="endpoints">The endpoints at the current DFA node, in priority order.</param>
    /// <returns>The policy edges; each edge lists its endpoints in their original relative order.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="endpoints"/> is <see langword="null"/>.</exception>
    public IReadOnlyList<PolicyNodeEdge> GetEdges(IReadOnlyList<Endpoint> endpoints)
    {
        ArgumentNullException.ThrowIfNull(endpoints);

        // Preserving the relative order of endpoints inside every edge is important.
        //
        // First pass: collect every distinct required port at this node so the per-port lists exist
        // before any endpoint is added. Adding endpoints in this pass would break ordering, because an
        // unrestricted endpoint seen early must also land in edges for ports discovered later.
        var edges = new Dictionary<int, List<Endpoint>>();
        for (var i = 0; i < endpoints.Count; i++)
        {
            if (endpoints[i].Metadata.GetMetadata<LocalPortMetadata>() is { } metadata)
            {
                edges.TryAdd(metadata.Port, []);
            }
        }

        // Second pass: distribute endpoints in order.
        var anyPortEndpoints = new List<Endpoint>();
        for (var i = 0; i < endpoints.Count; i++)
        {
            var endpoint = endpoints[i];
            if (endpoint.Metadata.GetMetadata<LocalPortMetadata>() is { } metadata)
            {
                // Port-restricted endpoint: belongs only to its own port's edge (created in pass one).
                edges[metadata.Port].Add(endpoint);
            }
            else
            {
                // Unrestricted endpoint: matches every known port, and also ports with no edge at all.
                foreach (var matches in edges.Values)
                {
                    matches.Add(endpoint);
                }

                anyPortEndpoints.Add(endpoint);
            }
        }

        var result = new List<PolicyNodeEdge>(edges.Count + 1);
        foreach (var (port, matches) in edges)
        {
            result.Add(new PolicyNodeEdge(port, matches));
        }

        // Without this edge, a request on a port that no endpoint explicitly requires would go to the
        // jump table's exit destination and miss the unrestricted endpoints entirely.
        if (anyPortEndpoints.Count > 0)
        {
            result.Add(new PolicyNodeEdge(AnyPortState, anyPortEndpoints));
        }

        return result;
    }

    /// <summary>
    /// Builds the runtime jump table that maps a request's local port to the destination state
    /// produced for the matching edge.
    /// </summary>
    /// <param name="exitDestination">Destination used when no edge matches and no "any port" edge exists.</param>
    /// <param name="edges">The edges created by <see cref="GetEdges"/> with their assigned destinations.</param>
    /// <exception cref="ArgumentNullException"><paramref name="edges"/> is <see langword="null"/>.</exception>
    public PolicyJumpTable BuildJumpTable(int exitDestination, IReadOnlyList<PolicyJumpTableEdge> edges)
    {
        ArgumentNullException.ThrowIfNull(edges);

        var destinations = new List<(int localPort, int destination)>(edges.Count);
        for (var i = 0; i < edges.Count; i++)
        {
            var edge = edges[i];
            if (edge.State is int localPort)
            {
                destinations.Add((localPort, edge.Destination));
            }
            else if (ReferenceEquals(edge.State, AnyPortState))
            {
                // Ports without a dedicated edge fall through to the endpoints that accept any port.
                exitDestination = edge.Destination;
            }
        }

        return new LocalPortPolicyJumpTable(exitDestination, destinations.ToArray());
    }

    /// <summary>
    /// Jump table that linearly scans the (typically very small) set of known ports.
    /// </summary>
    private sealed class LocalPortPolicyJumpTable : PolicyJumpTable
    {
        private readonly (int localPort, int destination)[] _destinations;
        private readonly int _exitDestination;

        public LocalPortPolicyJumpTable(int exitDestination, (int localPort, int destination)[] destinations)
        {
            _exitDestination = exitDestination;
            _destinations = destinations;
        }

        /// <summary>
        /// Returns the destination for the request's <c>Connection.LocalPort</c>, or the exit destination
        /// when the port has no dedicated edge.
        /// </summary>
        public override int GetDestination(HttpContext httpContext)
        {
            var localPort = httpContext.Connection.LocalPort;
            var destinations = _destinations;
            for (var i = 0; i < destinations.Length; i++)
            {
                var candidate = destinations[i];
                if (candidate.localPort == localPort)
                {
                    return candidate.destination;
                }
            }

            return _exitDestination;
        }
    }

    /// <summary>
    /// Node-builder applicability: never applies when dynamic endpoints are present (their metadata is
    /// not known until request time); otherwise applies when any endpoint carries <see cref="LocalPortMetadata"/>.
    /// </summary>
    bool INodeBuilderPolicy.AppliesToEndpoints(IReadOnlyList<Endpoint> endpoints)
    {
        if (ContainsDynamicEndpoints(endpoints))
        {
            return false;
        }

        return AppliesToEndpointsCore(endpoints);
    }

    /// <summary>Shared check: does any endpoint carry <see cref="LocalPortMetadata"/>?</summary>
    private static bool AppliesToEndpointsCore(IReadOnlyList<Endpoint> endpoints)
    {
        return endpoints.Any(e => e.Metadata.GetMetadata<LocalPortMetadata>() is not null);
    }

    /// <inheritdoc />
    public IComparer<Endpoint> Comparer { get; } = new LocalPortMetadataComparer();

    /// <summary>
    /// Orders endpoints by presence of <see cref="LocalPortMetadata"/> using the base comparer's
    /// null-ness rule (an endpoint with the metadata is treated as more specific).
    /// </summary>
    private sealed class LocalPortMetadataComparer : EndpointMetadataComparer<LocalPortMetadata>
    {
        protected override int CompareMetadata(LocalPortMetadata? x, LocalPortMetadata? y)
        {
            // Port is a non-nullable int, so any non-null metadata counts as "present".
            return base.CompareMetadata(x, y);
        }
    }
}
namespace EnvisiaCms.Simple.Shared;
/// <summary>
/// Endpoint metadata requiring that a request arrive on a specific local port.
/// Added to endpoints by <c>RequireLocalPort</c> and compared against
/// <c>HttpContext.Connection.LocalPort</c> by <c>LocalPortMatcherPolicy</c>.
/// </summary>
/// <param name="Port">The local port the request must have been received on.</param>
public record LocalPortMetadata(int Port);
using Microsoft.AspNetCore.Builder;
namespace EnvisiaCms.Simple.Shared;
/// <summary>
/// Endpoint convention extensions for restricting endpoints to a local port.
/// </summary>
public static class LocalPortRouteExtensions
{
    /// <summary>
    /// Adds <see cref="LocalPortMetadata"/> for <paramref name="port"/> to every endpoint built by
    /// <paramref name="builder"/>. A matcher policy that honors this metadata (such as
    /// <c>LocalPortMatcherPolicy</c>) then only matches these endpoints for requests received on that port.
    /// </summary>
    /// <typeparam name="TBuilder">The endpoint convention builder type.</typeparam>
    /// <param name="builder">The builder to add the convention to.</param>
    /// <param name="port">The required local port, in the range 0 to 65535.</param>
    /// <returns>The same <paramref name="builder"/> instance, for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="port"/> is negative or greater than 65535.</exception>
    public static TBuilder RequireLocalPort<TBuilder>(this TBuilder builder, int port) where TBuilder : IEndpointConventionBuilder
    {
        ArgumentNullException.ThrowIfNull(builder);

        // A value outside the TCP port range can never equal a connection's local port, which would
        // make the endpoint silently unreachable; fail fast at configuration time instead.
        ArgumentOutOfRangeException.ThrowIfNegative(port);
        ArgumentOutOfRangeException.ThrowIfGreaterThan(port, System.Net.IPEndPoint.MaxPort);

        builder.Add(endpointBuilder => endpointBuilder.Metadata.Add(new LocalPortMetadata(port)));
        return builder;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment