Skip to content

Instantly share code, notes, and snippets.

@magical
Last active August 8, 2025 23:22
Show Gist options
  • Save magical/07cad24debe89db60da1bb421046fd51 to your computer and use it in GitHub Desktop.
Save magical/07cad24debe89db60da1bb421046fd51 to your computer and use it in GitHub Desktop.
Schnorr authentication protocol
// https://go.dev/play/p/Y0L8GFSCmZX
// https://go.dev/play/p/JeZf_WHAqy0 (earlier version without hash)
// https://mit6875.github.io/PAPERS/Schnorr-POK-DLOG.pdf
package main
import (
"bytes"
"encoding/binary"
"fmt"
"log"
"filippo.io/bigmod"
"filippo.io/nistec"
cryptorand "crypto/rand"
"crypto/sha256"
)
// order is the order n of the P-256 base point, i.e. the modulus for the
// scalars (keys, nonces, challenges) used throughout this file.
const order = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551

// n0..n3 split order into 64-bit limbs, least significant limb first.
const (
	n0 = uint64(order & (1<<64 - 1))
	n1 = uint64(order >> 64 & (1<<64 - 1))
	n2 = uint64(order >> 128 & (1<<64 - 1))
	n3 = uint64(order >> 192)
)
// orderBytes is the 32-byte big-endian encoding of order, built from its
// limbs (most significant first) in the form bigmod.NewModulus expects.
var orderBytes = func() []byte {
	out := make([]byte, 32)
	for i, limb := range [...]uint64{n3, n2, n1, n0} {
		binary.BigEndian.PutUint64(out[8*i:], limb)
	}
	return out
}()
// main runs one round of the interactive Schnorr identification protocol
// over P-256 and prints whether the verifier accepts.
//
// Party A (prover) holds the secret key sk, with pk = sk·G. Party B
// (verifier) knows only pk. One round is:
//
//	A: commit     R = r·G for a fresh random nonce r, and publish H(R)
//	B: challenge  with a random scalar e
//	A: respond    s = r - sk·e (mod n)
//	B: accept iff H(s·G + e·pk) == H(R)
//
// Both roles run in this single process. The "A:"/"B:" labels below mark
// which side each step belongs to. Any failing curve or modulus operation
// aborts the program via bloop.
func main() {
	log.SetFlags(0)
	m, err := bigmod.NewModulus(orderBytes)
	bloop(err)

	// A: generate the key pair (sk, pk = sk·G).
	sk := rand(m)
	fmt.Printf("sk = %x\n", sk.Bytes(m))
	pk, err := nistec.NewP256Point().ScalarBaseMult(sk.Bytes(m))
	bloop(err)

	// A: commit to a fresh nonce r by publishing H(R), where R = r·G.
	// r must never be reused: two responses for the same r and different
	// challenges let anyone solve for sk.
	r := rand(m)
	fmt.Printf("r = %x\n", r.Bytes(m))
	R, err := nistec.NewP256Point().ScalarBaseMult(r.Bytes(m))
	bloop(err)
	hR := hash(R.BytesCompressed())

	// B: send a random challenge e.
	e := rand(m)
	fmt.Printf("e = %x\n", e.Bytes(m))

	// A: respond with s = r - sk·e (mod n). Each fresh Nat is zero and sized
	// for m, so the first Add(x, m) simply loads x before the next operation.
	se := bigmod.NewNat().ExpandFor(m).Add(sk, m).Mul(e, m)
	s := bigmod.NewNat().ExpandFor(m).Add(r, m).Sub(se, m)

	// B: recompute X = s·G + e·pk. For an honest A this equals
	// (r - sk·e)·G + e·sk·G = r·G = R.
	S, err := nistec.NewP256Point().ScalarBaseMult(s.Bytes(m))
	bloop(err)
	X, err := nistec.NewP256Point().ScalarMult(pk, e.Bytes(m))
	bloop(err)
	X.Add(X, S)
	hX := hash(X.BytesCompressed())
	fmt.Printf("%x %x\n", X.BytesCompressed(), R.BytesCompressed())
	fmt.Printf("%x %x\n", hX, hR)

	// No need to be constant time, since all inputs are public values
	// (R, s, e, pk).
	equal := bytes.Equal(hX, hR)
	fmt.Println(equal)
}
// rand returns a random scalar in [0, m), drawn from crypto/rand.
//
// It uses rejection sampling. 32 random bytes are read as a big-endian
// integer, and a new draw is made whenever SetBytes rejects the value for
// not being below m. This avoids the bias a modular reduction would add.
// The 32-byte buffer matches a 256-bit modulus such as the P-256 order used
// in this file.
// NOTE(review): a modulus of a different byte length would need a
// differently sized buffer.
//
// If the system randomness source fails, the program is aborted via bloop
// instead of continuing with an unfilled (predictable) buffer.
func rand(m *bigmod.Modulus) *bigmod.Nat {
	buf := make([]byte, 32)
	n := bigmod.NewNat()
	for {
		// A failed read could leave buf partly or wholly unfilled. That would
		// produce a guessable secret key or nonce, so it must not be ignored.
		_, err := cryptorand.Read(buf)
		bloop(err)
		if _, err := n.SetBytes(buf, m); err != nil {
			// The drawn value is out of range for m; draw again.
			log.Println("resampling")
			continue
		}
		return n
	}
}
// bloop is a fail-fast helper. A nil err is a no-op. A non-nil err is logged
// and the process exits via log.Fatal.
func bloop(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
// hash returns the SHA-256 digest of b as a newly allocated 32-byte slice.
func hash(b []byte) []byte {
	digest := sha256.New()
	digest.Write(b)
	return digest.Sum(nil)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment