Last active
October 3, 2023 07:12
-
-
Save Zenithar/fd6fd9ca12abba63abfc4519c7855e5c to your computer and use it in GitHub Desktop.
Pinned TLS Dialer for Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"crypto/sha256" | |
"crypto/subtle" | |
"crypto/tls" | |
"crypto/x509" | |
"crypto/x509/pkix" | |
"encoding/asn1" | |
"encoding/hex" | |
"errors" | |
"fmt" | |
"net" | |
"net/http" | |
) | |
// maxCertificateCount is the maximum number of certificates a peer may
// present during the TLS handshake before the pinned dialer rejects it.
const maxCertificateCount = 25

var (
	// ErrNoPinMatch is returned when no certificate of the peer has a public
	// key fingerprint equal to the expected one.
	ErrNoPinMatch = errors.New("no certificate matches the expected fingerprint")

	// ErrCertificateChainTooLong is returned when the peer presents more than
	// maxCertificateCount certificates during the TLS handshake.
	ErrCertificateChainTooLong = fmt.Errorf("the certificate chain exceeds the maximum allowed length (%d)", maxCertificateCount)
)
// Dialer is the function shape used to open a connection to addr on the given
// network. Expressing it as a named function type gives callers a seam to
// substitute a fake implementation (for example in tests), and it has the same
// signature as the http.Transport DialTLSContext hook.
type Dialer func(ctx context.Context, network string, addr string) (net.Conn, error)
// PinnedDialer uses the given tlsconfig configuration to establish an initial | |
// connection with the remote peer, and validate the certificate public key | |
// fingerprint against the given fingerprint. | |
// | |
// Use this dialer to ensure a remote peer certificate. This helps to mitigate | |
// DNS based attacks or SSL/TLS MiTM which could be used to reroute/proxy TLS traffic | |
// through an unauthorized peer, and drive the risk to total confidentiality compromise. | |
func PinnedDialer(cfg *tls.Config, fingerPrint []byte) Dialer { | |
return func(ctx context.Context, network, addr string) (net.Conn, error) { | |
// Check argument | |
if cfg == nil { | |
return nil, errors.New("bootstrap TLS configuration must be provided") | |
} | |
// Clone the given configuration | |
clientConfig := cfg.Clone() | |
// Try to connect to the remote server first to retrieve certificates. | |
c, err := tls.Dial(network, addr, clientConfig) | |
if err != nil { | |
return nil, fmt.Errorf("unable to establish initial TLS connection to retrieve certificates: %w", err) | |
} | |
connState := c.ConnectionState() | |
keyPinValid := false | |
// Ensure acceptable certificate count | |
if len(connState.PeerCertificates) > maxCertificateCount { | |
return nil, ErrCertificateChainTooLong | |
} | |
// Iterate over all returned certificates | |
for _, peerCert := range connState.PeerCertificates { | |
// Check if context has error to stop the validation prematurely. | |
if err := ctx.Err(); err != nil { | |
return nil, err | |
} | |
// Compute public key certificate fingerprint | |
hash, err := PublicKeyFingerprint(peerCert) | |
if err != nil { | |
return c, fmt.Errorf("unable to compute public key fingerprint: %w", err) | |
} | |
// Check equality whith provided fingerprint | |
if subtle.ConstantTimeCompare(hash, fingerPrint) == 1 { | |
keyPinValid = true | |
} | |
// Continue to process all certificates | |
} | |
if !keyPinValid { | |
return nil, ErrNoPinMatch | |
} | |
return c, nil | |
} | |
} | |
// subjectPublicKeyInfo mirrors the SubjectPublicKeyInfo ASN.1 structure of
// RFC 5280, section 4.1. Field order matters: encoding/asn1 decodes struct
// fields positionally.
type subjectPublicKeyInfo struct {
	Algorithm        pkix.AlgorithmIdentifier
	SubjectPublicKey asn1.BitString
}

// PublicKeyFingerprint returns the SHA-256 digest of the public key carried
// by cert.
//
// The key is serialized with x509.MarshalPKIXPublicKey, and the resulting
// SubjectPublicKeyInfo is decoded. Only the contents of its subjectPublicKey
// BIT STRING are hashed; the algorithm identifier is not part of the digest.
//
// Note that this is NOT the full-SubjectPublicKeyInfo digest defined by
// RFC 7469 (HPKP "pin-sha256") or by RFC 6698 (DANE, selector 1). Pins
// computed with those definitions, or with common tooling that follows them,
// will therefore not match the output of this function.
//
// An error is returned when cert is nil, when its public key type cannot be
// marshaled (including a nil PublicKey), or when the marshaled key cannot be
// decoded.
func PublicKeyFingerprint(cert *x509.Certificate) ([]byte, error) {
	if cert == nil {
		return nil, errors.New("a non-nil certificate must be provided")
	}

	// Re-encode the parsed key as a DER SubjectPublicKeyInfo.
	der, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
	if err != nil {
		return nil, fmt.Errorf("unable to serialize key: %w", err)
	}

	// Unwrap the SubjectPublicKeyInfo to reach the raw key bits.
	var info subjectPublicKeyInfo
	if _, err = asn1.Unmarshal(der, &info); err != nil {
		return nil, fmt.Errorf("unable to extract DER content from the encoded public key: %w", err)
	}

	sum := sha256.Sum256(info.SubjectPublicKey.Bytes)
	return sum[:], nil
}
func main() { | |
resp, err := http.DefaultClient.Get("https://www.google.com") | |
if err != nil { | |
panic(err) | |
} | |
// Create all fingerprint from the certificate chain. | |
for _, peerCert := range resp.TLS.PeerCertificates { | |
h, err := PublicKeyFingerprint(peerCert) | |
if err != nil { | |
panic(err) | |
} | |
fmt.Println(hex.EncodeToString(h)) | |
} | |
// Extracted from the previous request. | |
fgr, _ := hex.DecodeString("94af08ac6bbe62bddb9ee8839f18b991290691c0b35db2651b58d98b6a4bea38") | |
// HTTP client with pinned dialer to enforce remote certificate check. | |
pinnedClient := &http.Client{ | |
Transport: &http.Transport{ | |
DialTLSContext: PinnedDialer( | |
// TLS Configuration used to establish initial connection to | |
// retrieve certificate chain. | |
&tls.Config{InsecureSkipVerify: true}, | |
// Expected fingerprint, raise an error if no certificate match | |
// the given fingerprint. | |
fgr, | |
), | |
}, | |
} | |
if _, err = pinnedClient.Get("https://www.google.com"); err != nil { | |
panic(err) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment