Skip to content

Instantly share code, notes, and snippets.

@Zenithar
Last active October 3, 2023 07:12
Show Gist options
  • Save Zenithar/fd6fd9ca12abba63abfc4519c7855e5c to your computer and use it in GitHub Desktop.
Save Zenithar/fd6fd9ca12abba63abfc4519c7855e5c to your computer and use it in GitHub Desktop.
Pinned TLS Dialer for Go
package main
import (
"context"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
"errors"
"fmt"
"net"
"net/http"
)
const (
// maxCertificateCount is the maximum number of peer certificates accepted
// from a TLS handshake. PinnedDialer refuses longer chains and reports
// ErrCertificateChainTooLong instead of fingerprinting them.
maxCertificateCount = 25
)
var (
// ErrNoPinMatch is returned by PinnedDialer when none of the certificates
// presented by the peer has a public key fingerprint equal to the expected
// fingerprint.
ErrNoPinMatch = errors.New("no certificate match the expected fingerprint")
// ErrCertificateChainTooLong is returned by PinnedDialer when the peer
// presents more than maxCertificateCount certificates during the TLS
// handshake.
ErrCertificateChainTooLong = fmt.Errorf("the certificate chain exceeds the maximum allowed length (%d)", maxCertificateCount)
)
// Dialer is the signature of a context-aware network dial function. It is the
// type returned by PinnedDialer and has the same shape as the
// http.Transport.DialTLSContext field, which also makes it easy to substitute
// a fake dialer in tests.
type Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// PinnedDialer returns a Dialer that opens a TLS connection using a clone of
// cfg and only hands the connection back when at least one certificate
// presented by the peer has a public key fingerprint (as computed by
// PublicKeyFingerprint) equal to fingerPrint.
//
// Use this dialer to ensure a remote peer certificate. This helps to mitigate
// DNS based attacks or SSL/TLS MiTM which could be used to reroute/proxy TLS
// traffic through an unauthorized peer, and drive the risk to total
// confidentiality compromise.
//
// The returned Dialer fails with:
//   - a plain error when cfg is nil;
//   - a wrapped dial/handshake error when the connection cannot be set up
//     (including cancellation or expiry of ctx while connecting);
//   - ErrCertificateChainTooLong when the peer sends more than
//     maxCertificateCount certificates;
//   - ErrNoPinMatch when no certificate matches fingerPrint.
//
// On every failure the underlying connection is closed and a nil net.Conn is
// returned, so callers never have to clean up after an error.
//
// NOTE(review): every certificate the peer presents is a pin candidate, not
// only the leaf. When cfg disables chain verification (InsecureSkipVerify),
// only possession of the leaf key is proven by the handshake, so pinning a
// non-leaf key is presumably not sufficient on its own — confirm the intended
// threat model with callers.
func PinnedDialer(cfg *tls.Config, fingerPrint []byte) Dialer {
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		// Check argument
		if cfg == nil {
			return nil, errors.New("bootstrap TLS configuration must be provided")
		}

		// Work on a clone so the caller's configuration is never shared with
		// the connection, and dial through tls.Dialer so that ctx bounds both
		// the TCP connect and the TLS handshake.
		dialer := &tls.Dialer{Config: cfg.Clone()}

		rawConn, err := dialer.DialContext(ctx, network, addr)
		if err != nil {
			return nil, fmt.Errorf("unable to establish initial TLS connection to retrieve certificates: %w", err)
		}

		// tls.Dialer documents that the returned net.Conn is a *tls.Conn; the
		// checked assertion only guards against a future change.
		conn, ok := rawConn.(*tls.Conn)
		if !ok {
			_ = rawConn.Close() // best effort: the connection is unusable anyway
			return nil, fmt.Errorf("unexpected connection type %T returned by the TLS dialer", rawConn)
		}

		if err := checkPinnedChain(ctx, conn.ConnectionState().PeerCertificates, fingerPrint); err != nil {
			// Never leak the socket, and never hand back a connection whose
			// peer failed the pin check.
			_ = conn.Close()
			return nil, err
		}

		return conn, nil
	}
}

// checkPinnedChain reports whether chain is acceptable: it must not exceed
// maxCertificateCount entries and at least one certificate must have a public
// key fingerprint equal to fingerPrint. It returns nil on success.
func checkPinnedChain(ctx context.Context, chain []*x509.Certificate, fingerPrint []byte) error {
	// Ensure acceptable certificate count before doing any hashing work.
	if len(chain) > maxCertificateCount {
		return ErrCertificateChainTooLong
	}

	matched := false
	for _, peerCert := range chain {
		// Stop the validation prematurely if the caller gave up.
		if err := ctx.Err(); err != nil {
			return err
		}

		// Compute public key certificate fingerprint
		hash, err := PublicKeyFingerprint(peerCert)
		if err != nil {
			return fmt.Errorf("unable to compute public key fingerprint: %w", err)
		}

		// Check equality with the provided fingerprint. There is deliberately
		// no early exit: every certificate is processed even after a match.
		if subtle.ConstantTimeCompare(hash, fingerPrint) == 1 {
			matched = true
		}
	}

	if !matched {
		return ErrNoPinMatch
	}
	return nil
}
// subjectPublicKeyInfo mirrors the SubjectPublicKeyInfo ASN.1 structure
// defined in RFC 5280. PublicKeyFingerprint uses it as the asn1.Unmarshal
// target for a DER-encoded PKIX public key so that the raw key bits can be
// read separately from the algorithm identifier.
type subjectPublicKeyInfo struct {
// Algorithm identifies the public key algorithm (and its parameters).
Algorithm pkix.AlgorithmIdentifier
// SubjectPublicKey holds the raw public key bits; its Bytes field is what
// PublicKeyFingerprint hashes.
SubjectPublicKey asn1.BitString
}
// PublicKeyFingerprint returns the SHA-256 digest of the public key carried by
// cert.
// https://www.rfc-editor.org/rfc/rfc6698
//
// The certificate's public key is first serialized to its PKIX/DER form, that
// encoding is then decoded as a SubjectPublicKeyInfo structure, and only the
// bytes of the SubjectPublicKey bit string (not the algorithm identifier) are
// hashed.
//
// An error is returned when cert is nil, when the public key cannot be
// serialized (for instance an unsupported key type), or when the serialized
// key cannot be decoded. The returned slice is nil in every error case.
func PublicKeyFingerprint(cert *x509.Certificate) ([]byte, error) {
	// Guard against a missing certificate.
	if cert == nil {
		return nil, errors.New("a non-nil certificate must be provided")
	}

	// PKIX/DER encoding of the public key.
	der, marshalErr := x509.MarshalPKIXPublicKey(cert.PublicKey)
	if marshalErr != nil {
		return nil, fmt.Errorf("unable to serialize key: %w", marshalErr)
	}

	// Peel off the SubjectPublicKeyInfo wrapper to reach the raw key bits.
	spki := subjectPublicKeyInfo{}
	_, parseErr := asn1.Unmarshal(der, &spki)
	if parseErr != nil {
		return nil, fmt.Errorf("unable to extract DER content from the encoded public key: %w", parseErr)
	}

	// SHA-256 over the raw key bits only.
	digest := sha256.New()
	digest.Write(spki.SubjectPublicKey.Bytes)
	return digest.Sum(nil), nil
}
// main demonstrates the pinned dialer end to end:
//
//  1. It fetches https://www.google.com with the default client and prints the
//     hex-encoded public key fingerprint of every certificate in the chain the
//     server presented.
//  2. It builds an HTTP client whose TLS connections go through PinnedDialer
//     with a hard-coded fingerprint and repeats the request.
//
// Any failure aborts the program with a panic.
func main() {
	resp, err := http.DefaultClient.Get("https://www.google.com")
	if err != nil {
		panic(err)
	}
	// Release the connection held by the response once main returns.
	defer resp.Body.Close()

	// resp.TLS is only populated for responses received over TLS.
	if resp.TLS == nil {
		panic(errors.New("response was not received over TLS"))
	}

	// Create all fingerprint from the certificate chain.
	for _, peerCert := range resp.TLS.PeerCertificates {
		h, err := PublicKeyFingerprint(peerCert)
		if err != nil {
			panic(err)
		}
		fmt.Println(hex.EncodeToString(h))
	}

	// Extracted from the previous request.
	// NOTE(review): this value was captured when the example was written and
	// presumably goes stale when the server rotates its keys; refresh it from
	// the output printed above if the pinned request starts failing.
	fgr, err := hex.DecodeString("94af08ac6bbe62bddb9ee8839f18b991290691c0b35db2651b58d98b6a4bea38")
	if err != nil {
		panic(err)
	}

	// HTTP client with pinned dialer to enforce remote certificate check.
	pinnedClient := &http.Client{
		Transport: &http.Transport{
			DialTLSContext: PinnedDialer(
				// TLS Configuration used to establish initial connection to
				// retrieve certificate chain. Regular chain verification is
				// switched off here, so the fingerprint comparison is the
				// only check applied to the peer.
				&tls.Config{InsecureSkipVerify: true},
				// Expected fingerprint, raise an error if no certificate match
				// the given fingerprint.
				fgr,
			),
		},
	}

	pinnedResp, err := pinnedClient.Get("https://www.google.com")
	if err != nil {
		panic(err)
	}
	// The body is not needed; close it so the connection is not leaked.
	defer pinnedResp.Body.Close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment