|
package main |
|
|
|
import (
	"net"
	"net/http"
	"net/http/httputil"
	"regexp"
	"strings"

	"github.com/samber/lo"
	"github.com/theckman/httpforwarded"
)
|
|
|
// Configuration |
|
|
|
type ConfigForwarded struct { |
|
SetOutbound ConfigEnabledForwardedHeaders |
|
PreserveInbound ConfigEnabledForwardedHeaders |
|
} |
|
|
|
// ConfigEnabledForwardedHeaders holds one on/off flag per family of
// forwarding headers.
type ConfigEnabledForwardedHeaders struct {
	// Forwarded toggles the standard "Forwarded" header (RFC 7239).
	Forwarded bool
	// XForwarded toggles X-Forwarded-For, X-Forwarded-Host and X-Forwarded-Proto.
	XForwarded bool
	// XRealIP toggles X-Real-IP.
	XRealIP bool
}

// Any reports whether at least one header family is enabled.
func (h ConfigEnabledForwardedHeaders) Any() bool {
	switch {
	case h.Forwarded, h.XForwarded, h.XRealIP:
		return true
	default:
		return false
	}
}
|
|
|
// Replacement for ProxyRequest.SetXForwarded() |
|
|
|
func SetForwardedHeaders(r *httputil.ProxyRequest, cfg ConfigForwarded) { |
|
// Cleanup proxy forward headers. |
|
// Some of them are automatically copied from the request, |
|
// and we don't want it |
|
r.Out.Header.Del("Forwarded") |
|
r.Out.Header.Del("X-Forwarded-Host") |
|
r.Out.Header.Del("X-Forwarded-Proto") |
|
r.Out.Header.Del("X-Forwarded-For") |
|
r.Out.Header.Del("X-Real-IP") |
|
|
|
if !cfg.SetOutbound.Any() { |
|
return |
|
} |
|
|
|
var forwardedFor []string |
|
var forwardedBy []string |
|
|
|
var forwardedHost string |
|
forwardedHost = r.In.Host |
|
|
|
var forwardedProto string |
|
if r.In.TLS == nil { |
|
forwardedProto = "http" |
|
} else { |
|
forwardedProto = "https" |
|
} |
|
|
|
// Preserve data from inbound headers |
|
if cfg.PreserveInbound.Any() { |
|
forwardedHeader, err := httpforwarded.Parse(r.In.Header.Values("Forwarded")) |
|
if err != nil { |
|
// log.Warning("malformed Forwarded header: %w", err) |
|
forwardedHeader = map[string][]string{} |
|
} |
|
xForwardHost := r.In.Header.Values("X-Forwarded-Host") |
|
xForwardProto := r.In.Header.Values("X-Forwarded-Proto") |
|
xForwardFor := r.In.Header.Values("X-Forwarded-For") |
|
xRealIp := r.In.Header.Values("X-Real-IP") |
|
|
|
// "for": retrieve from either "Forwarded", "X-Forwarded-For" or "X-Real-IP" |
|
// TODO: Fix conflicts using topologic sort |
|
switch { |
|
case cfg.PreserveInbound.Forwarded && len(forwardedHeader["for"]) > 0: |
|
forwardedFor = forwardedHeader["for"] |
|
case cfg.PreserveInbound.XForwarded && len(xForwardFor) > 0: |
|
forwardedFor = xForwardFor |
|
case cfg.PreserveInbound.XRealIP && len(xRealIp) > 0: |
|
forwardedFor = []string{xRealIp[0]} |
|
} |
|
|
|
// "by": There is no X- equivalent |
|
switch { |
|
case cfg.PreserveInbound.Forwarded && len(forwardedHeader["by"]) > 0: |
|
forwardedBy = forwardedHeader["by"] |
|
} |
|
|
|
// host: retrieve from either "Forwarded" or "X-Forwarded-Host" |
|
switch { |
|
case cfg.PreserveInbound.Forwarded && len(forwardedHeader["host"]) > 0: |
|
forwardedHost = forwardedHeader["host"][0] |
|
case cfg.PreserveInbound.XForwarded && len(xForwardHost) > 0: |
|
forwardedHost = xForwardHost[0] |
|
} |
|
|
|
// host: retrieve from either "Forwarded" or "X-Forwarded-Proto" |
|
switch { |
|
case cfg.PreserveInbound.Forwarded && len(forwardedHeader["proto"]) > 0: |
|
forwardedProto = forwardedHeader["proto"][0] |
|
case cfg.PreserveInbound.XForwarded && len(xForwardProto) > 0: |
|
forwardedProto = xForwardProto[0] |
|
} |
|
} |
|
|
|
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr) // discard port on client-side |
|
if err != nil { |
|
clientIP = r.In.RemoteAddr |
|
} |
|
forwardedFor = append(forwardedFor, clientIP) |
|
|
|
serverIP := r.Out.RemoteAddr // Preserve port in server-side |
|
forwardedBy = append(forwardedBy, serverIP) |
|
|
|
if cfg.SetOutbound.Forwarded { |
|
r.Out.Header.Set("Forwarded", httpforwarded.Format(map[string][]string{ |
|
"for": lo.Map(forwardedFor, func(value string, _ int) string { |
|
return formatNodeIdentifier(value, true) |
|
}), |
|
"by": lo.Map(forwardedBy, func(value string, _ int) string { |
|
return formatNodeIdentifier(value, true) |
|
}), |
|
"host": {forwardedHost}, |
|
"proto": {forwardedProto}, |
|
})) |
|
} |
|
if cfg.SetOutbound.XForwarded { |
|
r.Out.Header.Set("X-Forwarded-Host", forwardedHost) |
|
r.Out.Header.Set("X-Forwarded-Proto", forwardedProto) |
|
r.Out.Header.Del("X-Forwarded-For") |
|
for _, v := range forwardedFor { |
|
r.Out.Header.Add("X-Forwarded-For", formatNodeIdentifier(v, false)) |
|
} |
|
} |
|
if cfg.SetOutbound.XRealIP { |
|
r.Out.Header.Set("X-Real-IP", formatNodeIdentifier(forwardedFor[0], false)) |
|
} |
|
} |
|
|
|
// LegacyUnknownHopIP is the placeholder written for a hop whose address is
// unknown or not an IP, in headers that only accept a bare IP
// (X-Forwarded-For, X-Real-IP).
const LegacyUnknownHopIP = "0.0.0.0"

var (
	// numericPortRegex matches an RFC 7239 numeric port: one to five ASCII digits.
	numericPortRegex = regexp.MustCompile(`^\d{1,5}$`)
	// obfuscatedIdentifierRegex matches an RFC 7239 obfuscated node or port:
	// "_" followed by at least one letter, digit, ".", "_" or "-".
	obfuscatedIdentifierRegex = regexp.MustCompile(`^_[a-zA-Z0-9._-]+$`)
)
|
|
|
func formatNodeIdentifier(value string, newSyntax bool) string { |
|
// See https: //www.rfc-editor.org/rfc/rfc7239.html#section-6 |
|
|
|
// IPv4, IPv6 |
|
prettyIP := formatIP(value, newSyntax) |
|
if prettyIP != "" { |
|
return prettyIP |
|
} |
|
// IPv4:port, [IPv6]:port |
|
collonPos := strings.LastIndex(value, ":") |
|
if collonPos >= 0 { |
|
ip := value[0:collonPos] |
|
port := value[collonPos+1:] |
|
// port can be either a numeric value or an _obfuscate_identifier |
|
if !numericPortRegex.MatchString(port) && !obfuscatedIdentifierRegex.MatchString(port) { |
|
port = "" |
|
} |
|
prettyIP := formatIP(ip, newSyntax) |
|
|
|
if prettyIP != "" { |
|
if newSyntax && port != "" { |
|
return prettyIP + ":" + port |
|
} else { |
|
return prettyIP |
|
} |
|
} |
|
} |
|
|
|
if !newSyntax { |
|
return LegacyUnknownHopIP |
|
} |
|
|
|
if obfuscatedIdentifierRegex.MatchString(value) { |
|
// _obfuscated_value |
|
return value |
|
} else { |
|
// Unknown / unparseable |
|
return "unknown" |
|
} |
|
} |
|
|
|
func formatIP(value string, newSyntax bool) string { |
|
rawIp := net.ParseIP(value) |
|
if rawIp == nil { |
|
// Stupid way to transform [IPv6] -> IPv6 |
|
host, _, err := net.SplitHostPort(value + ":0") |
|
if err != nil { |
|
return "" |
|
} |
|
rawIp = net.ParseIP(host) |
|
if rawIp == nil { |
|
return "" |
|
} |
|
} |
|
ret := rawIp.String() |
|
if newSyntax { |
|
if ret == LegacyUnknownHopIP { |
|
return "unknown" |
|
} |
|
if strings.Contains(ret, ":") { |
|
// This is an IPv6 and must be enclosed in braces |
|
ret = "[" + ret + "]" |
|
} |
|
} |
|
return ret |
|
} |
|
|
|
// main does nothing; this package only provides the forwarding-header
// helpers defined in this file.
func main() {
}