Skip to content

Instantly share code, notes, and snippets.

@stefansundin
Created November 12, 2024 03:32
Show Gist options
  • Save stefansundin/ed2b2e0e9c84b22f3cc962138a423549 to your computer and use it in GitHub Desktop.
Save stefansundin/ed2b2e0e9c84b22f3cc962138a423549 to your computer and use it in GitHub Desktop.
Go dialer control: block outbound network requests to undesired IP ranges (e.g. private IP ranges) and to nonstandard ports.
package main
import (
"errors"
"fmt"
"net"
"net/http"
"os"
"syscall"
)
// privateIPBlocks holds the CIDR ranges that outbound connections must never
// reach: loopback, RFC1918 private space, link-local, shared (carrier-grade
// NAT) space, the IPv4 "this network" block, and the IPv6 equivalents.
var privateIPBlocks []*net.IPNet

// init parses the hard-coded CIDR list into privateIPBlocks. A parse failure
// means the literal list itself is wrong, so it panics at startup.
func init() {
	for _, cidr := range []string{
		"0.0.0.0/8",      // RFC1122 "this network"; dialing 0.0.0.0 typically reaches the local host
		"127.0.0.0/8",    // IPv4 loopback
		"10.0.0.0/8",     // RFC1918
		"100.64.0.0/10",  // RFC6598 shared address space (carrier-grade NAT)
		"172.16.0.0/12",  // RFC1918
		"192.168.0.0/16", // RFC1918
		"169.254.0.0/16", // RFC3927 link-local
		"::/128",         // IPv6 unspecified
		"::1/128",        // IPv6 loopback
		"fe80::/10",      // IPv6 link-local
		"fc00::/7",       // IPv6 unique local addr
	} {
		_, block, err := net.ParseCIDR(cidr)
		if err != nil {
			panic(fmt.Errorf("parse error on %q: %v", cidr, err))
		}
		privateIPBlocks = append(privateIPBlocks, block)
	}
}

// isPrivateIP reports whether ip must not be dialed. It returns true for
// loopback, link-local (unicast and multicast) and unspecified addresses, and
// for anything inside privateIPBlocks. IPv4-mapped IPv6 addresses are matched
// against the IPv4 ranges because the net.IP helpers and IPNet.Contains
// normalize them to 4-byte form.
//
// A nil ip (the result of a failed net.ParseIP) is reported as private so
// callers fail closed instead of allowing an address they could not inspect.
func isPrivateIP(ip net.IP) bool {
	if ip == nil {
		return true
	}
	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
		return true
	}
	for _, block := range privateIPBlocks {
		if block.Contains(ip) {
			return true
		}
	}
	return false
}
// protectedDialerControl is a net.Dialer Control hook. It is invoked with the
// concrete "ip:port" being dialed (after any DNS resolution) before the
// connection is made; returning an error aborts the dial. It rejects
// addresses that are not valid IP literals, private/internal IPs, and any
// port other than 80 or 443. The network and c parameters are unused.
func protectedDialerControl(network, address string, c syscall.RawConn) error {
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		return err
	}
	// net.ParseIP returns nil for anything it cannot parse, including IPv6
	// literals carrying a zone (e.g. "fe80::1%eth0", reachable from a URL such
	// as http://[fe80::1%25eth0]/). Reject it explicitly so the hook fails
	// closed rather than relying on how isPrivateIP treats a nil IP.
	ip := net.ParseIP(host)
	if ip == nil {
		return fmt.Errorf("unparseable IP address %q not allowed", host)
	}
	if isPrivateIP(ip) {
		return errors.New("private IP address not allowed.")
	}
	if port != "80" && port != "443" {
		return errors.New("nonstandard port not allowed.")
	}
	return nil
}
// main fetches one untrusted URL through an http.Client whose dialer runs
// protectedDialerControl on every connection (including redirects). The
// commented-out URLs show inputs the hook rejects, with the resulting error.
func main() {
	var untrustedURL string
	// untrustedURL = "http://--1.sslip.io/" // Get "http://--1.sslip.io/": dial tcp [::1]:80: private IP address not allowed.
	// untrustedURL = "http://www.192.168.0.1.sslip.io/" // Get "http://www.192.168.0.1.sslip.io/": dial tcp 192.168.0.1:80: private IP address not allowed.
	// untrustedURL = "http://127.0.0.1.sslip.io/" // Get "http://127.0.0.1.sslip.io/": dial tcp 127.0.0.1:80: private IP address not allowed.
	// untrustedURL = "https://wikipedia.org:8080/" // Get "https://wikipedia.org:8080/": dial tcp [2620:0:863:ed1a::1]:8080: nonstandard port not allowed.
	untrustedURL = "https://wikipedia.org/" // The only URL that works. 200 OK
	fmt.Println(untrustedURL)

	protectedTransport := &http.Transport{
		DialContext: (&net.Dialer{
			Control: protectedDialerControl,
		}).DialContext,
	}
	// NOTE(review): no client Timeout is set, so a slow server can stall this
	// request indefinitely; consider adding one for real use.
	client := &http.Client{
		Transport: protectedTransport,
	}
	resp, err := client.Get(untrustedURL)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return
	}
	// Close the body so the transport can release the underlying connection.
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
package main
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"os"
"syscall"
"time"
)
// privateIPBlocks holds the CIDR ranges that outbound connections must never
// reach: loopback, RFC1918 private space, link-local, shared (carrier-grade
// NAT) space, the IPv4 "this network" block, and the IPv6 equivalents.
var privateIPBlocks []*net.IPNet

// init parses the hard-coded CIDR list into privateIPBlocks. A parse failure
// means the literal list itself is wrong, so it panics at startup.
func init() {
	for _, cidr := range []string{
		"0.0.0.0/8",      // RFC1122 "this network"; dialing 0.0.0.0 typically reaches the local host
		"127.0.0.0/8",    // IPv4 loopback
		"10.0.0.0/8",     // RFC1918
		"100.64.0.0/10",  // RFC6598 shared address space (carrier-grade NAT)
		"172.16.0.0/12",  // RFC1918
		"192.168.0.0/16", // RFC1918
		"169.254.0.0/16", // RFC3927 link-local
		"::/128",         // IPv6 unspecified
		"::1/128",        // IPv6 loopback
		"fe80::/10",      // IPv6 link-local
		"fc00::/7",       // IPv6 unique local addr
	} {
		_, block, err := net.ParseCIDR(cidr)
		if err != nil {
			panic(fmt.Errorf("parse error on %q: %v", cidr, err))
		}
		privateIPBlocks = append(privateIPBlocks, block)
	}
}

// isPrivateIP reports whether ip must not be dialed. It returns true for
// loopback, link-local (unicast and multicast) and unspecified addresses, and
// for anything inside privateIPBlocks. IPv4-mapped IPv6 addresses are matched
// against the IPv4 ranges because the net.IP helpers and IPNet.Contains
// normalize them to 4-byte form.
//
// A nil ip (the result of a failed net.ParseIP) is reported as private so
// callers fail closed instead of allowing an address they could not inspect.
func isPrivateIP(ip net.IP) bool {
	if ip == nil {
		return true
	}
	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
		return true
	}
	for _, block := range privateIPBlocks {
		if block.Contains(ip) {
			return true
		}
	}
	return false
}
// protectedDialerControl is a net.Dialer Control hook. It is invoked with the
// concrete "ip:port" being dialed (after any DNS resolution) before the
// connection is made; returning an error aborts the dial. It rejects
// addresses that are not valid IP literals, private/internal IPs, and any
// port other than 80 or 443. The network and c parameters are unused.
func protectedDialerControl(network, address string, c syscall.RawConn) error {
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		return err
	}
	// net.ParseIP returns nil for anything it cannot parse, including IPv6
	// literals carrying a zone (e.g. "fe80::1%eth0", reachable from a URL such
	// as http://[fe80::1%25eth0]/). Reject it explicitly so the hook fails
	// closed rather than relying on how isPrivateIP treats a nil IP.
	ip := net.ParseIP(host)
	if ip == nil {
		return fmt.Errorf("unparseable IP address %q not allowed", host)
	}
	if isPrivateIP(ip) {
		return errors.New("private IP address not allowed.")
	}
	if port != "80" && port != "443" {
		return errors.New("nonstandard port not allowed.")
	}
	return nil
}
// main fetches one untrusted URL through an http.Client whose dialer resolves
// names via a custom resolver and then runs protectedDialerControl on every
// resolved address. The commented-out URLs are alternative test inputs.
func main() {
	var untrustedURL string
	// untrustedURL = "http://random-domain.com/"
	// untrustedURL = "http://--1.sslip.io/"
	untrustedURL = "http://www.192.168.0.1.sslip.io/"
	// untrustedURL = "http://127.0.0.1.sslip.io:8080/"
	// untrustedURL = "https://sslip.io:8080/"
	fmt.Println(untrustedURL)

	// resolver uses Go's built-in DNS client but always sends queries to
	// 8.8.8.8:53, ignoring the nameserver address it is handed (which is
	// printed for debugging).
	resolver := &net.Resolver{
		PreferGo: true,
		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
			fmt.Println(address)
			d := net.Dialer{
				Timeout: 10 * time.Second,
			}
			return d.DialContext(ctx, network, "8.8.8.8:53")
		},
	}
	protectedTransport := &http.Transport{
		DialContext: (&net.Dialer{
			Resolver: resolver,
			Control:  protectedDialerControl,
		}).DialContext,
	}
	// Bound the whole request (dial, redirects, headers, body) so an untrusted
	// server cannot stall the program indefinitely.
	client := &http.Client{
		Transport: protectedTransport,
		Timeout:   30 * time.Second,
	}
	resp, err := client.Get(untrustedURL)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return
	}
	// Close the body so the transport can release the underlying connection.
	defer resp.Body.Close()
	fmt.Println(resp)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment