Skip to content

Instantly share code, notes, and snippets.

@Zenithar
Last active October 9, 2023 16:10
Show Gist options
  • Save Zenithar/08765568c440aafb833f6a2cc9ede2dd to your computer and use it in GitHub Desktop.
Save Zenithar/08765568c440aafb833f6a2cc9ede2dd to your computer and use it in GitHub Desktop.
This is an implementation of the DH-based KEM (DHKEM) described in RFC 9180 - https://datatracker.ietf.org/doc/rfc9180/
package dhkem
import (
"crypto/ecdh"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"hash"
"io"
"golang.org/x/crypto/hkdf"
)
// Sentinel errors exposed by the package.
var (
	// ErrDeserialization reports that the provided bytes could not be decoded
	// as a key of the expected type.
	ErrDeserialization = errors.New("unable to deserialize key content")

	// ErrEncap reports that an error occurred during shared secret
	// encapsulation.
	ErrEncap = errors.New("unable to encapsulate the shared secret")

	// ErrDecap reports that an error occurred during shared secret
	// decapsulation.
	ErrDecap = errors.New("unable to decapsulate the shared secret")
)
// dhkem implements the DH-based KEM described in
// https://www.rfc-editor.org/rfc/rfc9180.html#name-dh-based-kem-dhkem
type dhkem struct {
	// kemID is the encoded KEM identifier; it is appended to "KEM" to build
	// the suite identifier.
	kemID []byte
	// curve generates key pairs and decodes serialized public keys.
	curve ecdh.Curve
	// fh is the hash constructor handed to the HKDF extract and expand steps.
	fh func() hash.Hash
	// nSecret is the length in bytes of the derived shared secret.
	nSecret uint16
	// nEnc is the expected length in bytes of an encapsulated key.
	nEnc uint16
	// nPk is the expected length in bytes of a serialized public key.
	nPk uint16
	// nSk is the private key length in bytes.
	// NOTE(review): no visible code reads this field.
	nSk uint16
}
// KEMID returns the KEM suite identifier.
//
// The returned slice is a copy, so callers may modify it without affecting
// the suite identifier used for later derivations.
func (kem *dhkem) KEMID() []byte {
	id := make([]byte, len(kem.kemID))
	copy(id, kem.kemID)
	return id
}
// SuiteID returns the public suite identifier used for material derivation.
// It is the ASCII prefix "KEM" followed by the KEM identifier bytes.
func (kem *dhkem) SuiteID() []byte {
	// suite_id = concat("KEM", I2OSP(kem_id, 2))
	const prefix = "KEM"

	suiteID := make([]byte, 0, len(prefix)+len(kem.kemID))
	suiteID = append(suiteID, prefix...)
	suiteID = append(suiteID, kem.kemID...)
	return suiteID
}
// GenerateKeyPair creates a fresh key pair on the suite curve, drawing
// randomness from crypto/rand. It returns the public key first, then the
// private key.
func (kem *dhkem) GenerateKeyPair() (*ecdh.PublicKey, *ecdh.PrivateKey, error) {
	priv, genErr := kem.curve.GenerateKey(rand.Reader)
	if genErr != nil {
		return nil, nil, fmt.Errorf("unable to generate key pair from the suite: %w", genErr)
	}

	pub := priv.PublicKey()
	return pub, priv, nil
}
// SerializePublicKey returns the byte encoding of the given public key.
//
// It panics when the encoding length differs from the public key size
// expected by the suite.
func (kem *dhkem) SerializePublicKey(pkX *ecdh.PublicKey) []byte {
	encoded := pkX.Bytes()
	if want := int(kem.nPk); len(encoded) != want {
		panic("invalid public key size")
	}
	return encoded
}
// DeserializePublicKey decodes pkXxm as a public key on the suite curve.
//
// The input must be exactly as long as the suite public key size. Every
// failure wraps ErrDeserialization so callers can match it with errors.Is;
// a decoding failure additionally wraps the error reported by the curve.
func (kem *dhkem) DeserializePublicKey(pkXxm []byte) (*ecdh.PublicKey, error) {
	if len(pkXxm) != int(kem.nPk) {
		return nil, fmt.Errorf("public key data size is invalid: %w", ErrDeserialization)
	}

	pk, err := kem.curve.NewPublicKey(pkXxm)
	if err != nil {
		return nil, fmt.Errorf("unable to decode public key: %w: %w", err, ErrDeserialization)
	}
	return pk, nil
}
// Encap generates an ephemeral key pair, computes the shared secret against
// the receiver static public key pkR and returns it together with the
// serialized ephemeral public key (enc) to transmit to the receiver.
//
// Every returned error wraps ErrEncap.
func (kem *dhkem) Encap(pkR *ecdh.PublicKey) (ss, enc []byte, err error) {
	if pkR == nil {
		return nil, nil, fmt.Errorf("receiver public key is nil: %w", ErrEncap)
	}

	// skE, pkE = GenerateKeyPair()
	pkE, skE, err := kem.GenerateKeyPair()
	if err != nil {
		return nil, nil, fmt.Errorf("unable to generate ephemeral keypair: %w: %w", err, ErrEncap)
	}

	// dh = DH(skE, pkR)
	dh, err := skE.ECDH(pkR)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to compute key agreement: %w: %w", err, ErrEncap)
	}

	// enc = SerializePublicKey(pkE)
	enc = kem.SerializePublicKey(pkE)
	if len(enc) != int(kem.nEnc) {
		return nil, nil, fmt.Errorf("invalid encapsulation size: %w", ErrEncap)
	}

	// pkRm = SerializePublicKey(pkR)
	pkRm := kem.SerializePublicKey(pkR)

	// kem_context = concat(enc, pkRm)
	// Built in a dedicated buffer so the context never shares a backing array
	// with the enc slice handed back to the caller.
	kemContext := make([]byte, 0, len(enc)+len(pkRm))
	kemContext = append(kemContext, enc...)
	kemContext = append(kemContext, pkRm...)

	// shared_secret = ExtractAndExpand(dh, kem_context)
	ss, err = kem.extractAndExpand(dh, kemContext)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to compute shared secret: %w: %w", err, ErrEncap)
	}

	return ss, enc, nil
}
// Decap recovers the shared secret from the encapsulated public key enc using
// the receiver static private key skR.
//
// enc is only read: the KEM context is assembled in a separate buffer, so a
// caller may safely pass a sub-slice of a larger message. Every returned
// error wraps ErrDecap.
func (kem *dhkem) Decap(enc []byte, skR *ecdh.PrivateKey) ([]byte, error) {
	if skR == nil {
		return nil, fmt.Errorf("receiver private key is nil: %w", ErrDecap)
	}
	if len(enc) != int(kem.nEnc) {
		return nil, fmt.Errorf("invalid encapsulation size: %w", ErrDecap)
	}

	// pkE = DeserializePublicKey(enc)
	pkE, err := kem.DeserializePublicKey(enc)
	if err != nil {
		return nil, fmt.Errorf("unable to deserialize public key: %w: %w", err, ErrDecap)
	}

	// dh = DH(skR, pkE)
	dh, err := skR.ECDH(pkE)
	if err != nil {
		return nil, fmt.Errorf("unable to compute key agreement: %w: %w", err, ErrDecap)
	}

	// pkRm = SerializePublicKey(pk(skR))
	pkRm := kem.SerializePublicKey(skR.PublicKey())

	// kem_context = concat(enc, pkRm)
	// Appending directly to enc would overwrite the bytes following it in the
	// caller's buffer whenever enc has spare capacity, hence the copy.
	kemContext := make([]byte, 0, len(enc)+len(pkRm))
	kemContext = append(kemContext, enc...)
	kemContext = append(kemContext, pkRm...)

	// shared_secret = ExtractAndExpand(dh, kem_context)
	ss, err := kem.extractAndExpand(dh, kemContext)
	if err != nil {
		return nil, fmt.Errorf("unable to compute shared secret: %w: %w", err, ErrDecap)
	}

	return ss, nil
}
// AuthEncap computes a shared secret and an encapsulated public key (enc)
// bound to both the receiver static public key pkR and the sender static
// private key skS, so the receiver can authenticate the sender.
//
// Every returned error wraps ErrEncap, and also wraps the underlying cause
// when there is one.
func (kem *dhkem) AuthEncap(pkR *ecdh.PublicKey, skS *ecdh.PrivateKey) (ss, enc []byte, err error) {
	if pkR == nil || skS == nil {
		return nil, nil, fmt.Errorf("receiver public key and sender private key are required: %w", ErrEncap)
	}

	// skE, pkE = GenerateKeyPair()
	pkE, skE, err := kem.GenerateKeyPair()
	if err != nil {
		return nil, nil, fmt.Errorf("unable to generate ephemeral keypair: %w: %w", err, ErrEncap)
	}

	dhE, err := skE.ECDH(pkR)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to compute ephemeral key agreement: %w: %w", err, ErrEncap)
	}
	dhS, err := skS.ECDH(pkR)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to compute static key agreement: %w: %w", err, ErrEncap)
	}

	// dh = concat(DH(skE, pkR), DH(skS, pkR))
	dh := make([]byte, 0, len(dhE)+len(dhS))
	dh = append(dh, dhE...)
	dh = append(dh, dhS...)

	// enc = SerializePublicKey(pkE)
	enc = kem.SerializePublicKey(pkE)
	if len(enc) != int(kem.nEnc) {
		return nil, nil, fmt.Errorf("invalid encapsulation size: %w", ErrEncap)
	}

	pkRm := kem.SerializePublicKey(pkR)
	pkSm := kem.SerializePublicKey(skS.PublicKey())

	// kem_context = concat(enc, pkRm, pkSm)
	// Built in a dedicated buffer so the context never shares a backing array
	// with the enc slice handed back to the caller.
	kemContext := make([]byte, 0, len(enc)+len(pkRm)+len(pkSm))
	kemContext = append(kemContext, enc...)
	kemContext = append(kemContext, pkRm...)
	kemContext = append(kemContext, pkSm...)

	// shared_secret = ExtractAndExpand(dh, kem_context)
	ss, err = kem.extractAndExpand(dh, kemContext)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to compute shared secret: %w: %w", err, ErrEncap)
	}

	return ss, enc, nil
}
// AuthDecap recovers the shared secret from the encapsulated public key enc
// using the receiver static private key skR and the sender static public key
// pkS, which authenticates the sender.
//
// enc is only read. Every returned error wraps ErrDecap, and also wraps the
// underlying cause when there is one.
func (kem *dhkem) AuthDecap(enc []byte, skR *ecdh.PrivateKey, pkS *ecdh.PublicKey) ([]byte, error) {
	if skR == nil || pkS == nil {
		return nil, fmt.Errorf("receiver private key and sender public key are required: %w", ErrDecap)
	}
	if len(enc) != int(kem.nEnc) {
		return nil, fmt.Errorf("invalid encapsulation size: %w", ErrDecap)
	}

	// pkE = DeserializePublicKey(enc)
	pkE, err := kem.DeserializePublicKey(enc)
	if err != nil {
		return nil, fmt.Errorf("unable to deserialize public key: %w: %w", err, ErrDecap)
	}

	dhE, err := skR.ECDH(pkE)
	if err != nil {
		return nil, fmt.Errorf("unable to compute ephemeral key agreement: %w: %w", err, ErrDecap)
	}
	dhS, err := skR.ECDH(pkS)
	if err != nil {
		return nil, fmt.Errorf("unable to compute static key agreement: %w: %w", err, ErrDecap)
	}

	// dh = concat(DH(skR, pkE), DH(skR, pkS))
	dh := make([]byte, 0, len(dhE)+len(dhS))
	dh = append(dh, dhE...)
	dh = append(dh, dhS...)

	pkRm := kem.SerializePublicKey(skR.PublicKey())
	pkSm := kem.SerializePublicKey(pkS)

	// kem_context = concat(enc, pkRm, pkSm)
	// The received enc is copied into a dedicated buffer rather than appended
	// to, so the caller's slice is never written.
	kemContext := make([]byte, 0, len(enc)+len(pkRm)+len(pkSm))
	kemContext = append(kemContext, enc...)
	kemContext = append(kemContext, pkRm...)
	kemContext = append(kemContext, pkSm...)

	// shared_secret = ExtractAndExpand(dh, kem_context)
	ss, err := kem.extractAndExpand(dh, kemContext)
	if err != nil {
		return nil, fmt.Errorf("unable to compute shared secret: %w: %w", err, ErrDecap)
	}

	return ss, nil
}
// -----------------------------------------------------------------------------
// extractAndExpand derives the shared secret from the raw DH output and the
// KEM context: a labeled extract with an empty salt, followed by a labeled
// expand to nSecret bytes.
func (kem *dhkem) extractAndExpand(dh, kemContext []byte) ([]byte, error) {
	var (
		emptySalt    = []byte("")
		extractLabel = []byte("eae_prk")
		expandLabel  = []byte("shared_secret")
	)

	// eae_prk = LabeledExtract("", "eae_prk", dh)
	prk := kem.labeledExtract(emptySalt, extractLabel, dh)

	// shared_secret = LabeledExpand(eae_prk, "shared_secret", kem_context, Nsecret)
	return kem.labeledExpand(prk, expandLabel, kemContext, kem.nSecret)
}
// labeledExtract runs HKDF-Extract with the given salt over the input keying
// material prefixed by the protocol version, the suite identifier and label.
func (kem *dhkem) labeledExtract(salt, label, ikm []byte) []byte {
	const version = "HPKE-v1"

	suiteID := kem.SuiteID()

	// labeled_ikm = concat("HPKE-v1", suite_id, label, ikm)
	labeledIKM := make([]byte, 0, len(version)+len(suiteID)+len(label)+len(ikm))
	labeledIKM = append(labeledIKM, version...)
	labeledIKM = append(labeledIKM, suiteID...)
	labeledIKM = append(labeledIKM, label...)
	labeledIKM = append(labeledIKM, ikm...)

	return hkdf.Extract(kem.fh, labeledIKM, salt)
}
// labeledExpand runs HKDF-Expand on prk and returns L bytes of output. The
// info string is the big-endian 2-byte length L followed by the protocol
// version, the suite identifier, label and info.
func (kem *dhkem) labeledExpand(prk, label, info []byte, L uint16) ([]byte, error) {
	const version = "HPKE-v1"

	suiteID := kem.SuiteID()

	// labeled_info = concat(I2OSP(L, 2), "HPKE-v1", suite_id, label, info)
	labeledInfo := make([]byte, 2, 2+len(version)+len(suiteID)+len(label)+len(info))
	binary.BigEndian.PutUint16(labeledInfo, L)
	labeledInfo = append(labeledInfo, version...)
	labeledInfo = append(labeledInfo, suiteID...)
	labeledInfo = append(labeledInfo, label...)
	labeledInfo = append(labeledInfo, info...)

	secret := make([]byte, L)
	stream := hkdf.Expand(kem.fh, prk, labeledInfo)
	if _, err := io.ReadFull(stream, secret); err != nil {
		return nil, fmt.Errorf("unable to generate secret from prf: %w", err)
	}

	return secret, nil
}
package dhkem
import (
"crypto/ecdh"
"crypto/sha256"
"crypto/sha512"
)
// Suite describes the operations exposed by a KEM suite.
type Suite interface {
	// SuiteID returns the suite identifier used for material derivation.
	SuiteID() []byte
	// KEMID returns the KEM identifier.
	KEMID() []byte

	// GenerateKeyPair creates a new key pair for the suite.
	GenerateKeyPair() (*ecdh.PublicKey, *ecdh.PrivateKey, error)
	// SerializePublicKey encodes a public key as bytes.
	SerializePublicKey(pkX *ecdh.PublicKey) []byte
	// DeserializePublicKey decodes bytes into a public key.
	DeserializePublicKey(pkXxm []byte) (*ecdh.PublicKey, error)

	// Encap returns a shared secret and its encapsulation for the receiver
	// public key.
	Encap(pkR *ecdh.PublicKey) (ss, enc []byte, err error)
	// Decap returns the shared secret from an encapsulation and the receiver
	// private key.
	Decap(enc []byte, skR *ecdh.PrivateKey) ([]byte, error)

	// AuthEncap is Encap with the sender static private key mixed in.
	AuthEncap(pkR *ecdh.PublicKey, skS *ecdh.PrivateKey) (ss, enc []byte, err error)
	// AuthDecap is Decap with the sender static public key mixed in.
	AuthDecap(enc []byte, skR *ecdh.PrivateKey, pkS *ecdh.PublicKey) ([]byte, error)
}
// P256HKDFSHA256 returns a KEM suite based on ECDH over the P-256 curve with
// HKDF-SHA256 for shared secret derivation.
func P256HKDFSHA256() Suite {
	suite := &dhkem{
		kemID: []byte{0x00, 0x10},
		curve: ecdh.P256(),
		fh:    sha256.New,

		// Sizes in bytes.
		nSecret: 32,
		nEnc:    65,
		nPk:     65,
		nSk:     32,
	}
	return suite
}
// P384HKDFSHA384 returns a KEM suite based on ECDH over the P-384 curve with
// HKDF-SHA384 for shared secret derivation.
func P384HKDFSHA384() Suite {
	suite := &dhkem{
		kemID: []byte{0x00, 0x11},
		curve: ecdh.P384(),
		fh:    sha512.New384,

		// Sizes in bytes.
		nSecret: 48,
		nEnc:    97,
		nPk:     97,
		nSk:     48,
	}
	return suite
}
// P521HKDFSHA512 returns a KEM suite based on ECDH over the P-521 curve with
// HKDF-SHA512 for shared secret derivation.
func P521HKDFSHA512() Suite {
	suite := &dhkem{
		kemID: []byte{0x00, 0x12},
		curve: ecdh.P521(),
		fh:    sha512.New,

		// Sizes in bytes.
		nSecret: 64,
		nEnc:    133,
		nPk:     133,
		nSk:     66,
	}
	return suite
}
// X25519HKDFSHA256 returns a KEM suite based on X25519 with HKDF-SHA256 for
// shared secret derivation.
func X25519HKDFSHA256() Suite {
	suite := &dhkem{
		kemID: []byte{0x00, 0x20},
		curve: ecdh.X25519(),
		fh:    sha256.New,

		// Sizes in bytes.
		nSecret: 32,
		nEnc:    32,
		nPk:     32,
		nSk:     32,
	}
	return suite
}
package dhkem_test
import (
"testing"
"github.com/stretchr/testify/require"
"zntr.io/dhkem"
)
// TestEncapDecap checks that Decap recovers the secret produced by Encap for
// the same receiver key pair.
func TestEncapDecap(t *testing.T) {
	kem := dhkem.P256HKDFSHA256()

	// Receiver long-term key pair.
	pkR, skR, err := kem.GenerateKeyPair()
	require.NoError(t, err)

	// Sender side.
	senderSecret, enc, err := kem.Encap(pkR)
	require.NoError(t, err)

	// Receiver side.
	receiverSecret, err := kem.Decap(enc, skR)
	require.NoError(t, err)

	require.Equal(t, senderSecret, receiverSecret)
}
// TestAuthEncapAuthDecap checks that AuthDecap recovers the secret produced
// by AuthEncap for matching sender and receiver key pairs.
func TestAuthEncapAuthDecap(t *testing.T) {
	kem := dhkem.P256HKDFSHA256()

	// Receiver long-term key pair.
	receiverPub, receiverPriv, err := kem.GenerateKeyPair()
	require.NoError(t, err)

	// Sender long-term key pair.
	senderPub, senderPriv, err := kem.GenerateKeyPair()
	require.NoError(t, err)

	// Sender side.
	senderSecret, enc, err := kem.AuthEncap(receiverPub, senderPriv)
	require.NoError(t, err)

	// Receiver side.
	receiverSecret, err := kem.AuthDecap(enc, receiverPriv, senderPub)
	require.NoError(t, err)

	require.Equal(t, senderSecret, receiverSecret)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment