Skip to content

Instantly share code, notes, and snippets.

@AdamISZ
Created April 25, 2021 16:24
Show Gist options
  • Save AdamISZ/3c09bc235654b28ca1192aa1f21fe3ce to your computer and use it in GitHub Desktop.
Save AdamISZ/3c09bc235654b28ca1192aa1f21fe3ce to your computer and use it in GitHub Desktop.
MuSig2 toy implementation in Python for learning purposes
""" THIS CODE IS ONLY EDUCATIONAL - NO
PART OF IT IS FIT FOR PRODUCTION USE.
NO, SERIOUSLY, I MEAN IT!!
As for reading it, start with the `__main__`
section at the bottom and go from there.
Comments are, deliberately, voluminous.
If you want to run the example, just:
(a) install Joinmarket (else see the notes on import)
(b) `source jmvenv/bin/activate` from joinmarket-clientserver directory
(c) run this script as `python musig2-demo.py 3 "hello"`, for
3 participants signing "hello", or change arguments as preferred.
There is no obvious limit on the number of participants, although
1 does not work :)
A last generic comment: see the comments in `schnorr_challenge()` as
to why these signatures would not currently be compatible with a real
BIP340 implementation; but (a) that's probably for the best and (b) it
should be quite trivial to fix.
"""
import os
import sys
import hashlib
# needs installation of Joinmarket;
# note otherwise you will need implementations
# of multiply (scalar mult of secp256k1 curve point),
# add_pubkeys (add list of curve points),
# privkey-to-pubkey (could be done with multiply),
# and bin-to-hex conversion.
from jmbitcoin import multiply, add_pubkeys, privkey_to_pubkey
from jmbase import bintohex
# Order of the secp256k1 group generated by G. All scalar arithmetic below
# (private keys, nonces, challenge values, partial signatures) is done
# modulo this value -- in plain Python ints, so for education only!
N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
def flip_priv_if_not_even_y(scalar):
    """ Normalize a private scalar so that scalar * G has an even y-coordinate.

    BIP340 identifies points by their x-coordinate only and implicitly
    uses the even-y point, so in some places we need the private scalar
    that matches that point.

    Args:
        scalar: 32 byte big-endian private scalar.
    Returns:
        `(scalar, False)` if scalar * G already has even y, otherwise
        `(N - scalar, True)`; the scalar is (re)serialized as 32 bytes.
    """
    # The backend returns a compressed pubkey when 0x01 is appended to the
    # private key; a leading 0x02 byte means even y.
    is_even = privkey_to_pubkey(scalar + b"\x01")[0] == 2
    value = int.from_bytes(scalar, byteorder="big")
    if not is_even:
        value = N - value
    return value.to_bytes(32, byteorder="big"), not is_even
def flip_pub_if_not_even_y(P):
    """ Convert a 33 byte compressed public key to BIP340 x-only form.

    Returns `(x_only, False)` if P has an even y-coordinate (prefix 0x02),
    otherwise `(x_only, True)`, where `x_only` is then the serialization
    of -P. Negating a curve point only changes the parity of y and leaves
    x unchanged, so P and -P have the same 32 byte x-only serialization;
    the flag tells the caller that the corresponding private scalar(s)
    must be negated to match the even-y point.

    Raises:
        ValueError: if P is not a 33 byte compressed public key.
    """
    if len(P) != 33 or P[0] not in (2, 3):
        raise ValueError("expected a 33 byte compressed public key")
    # No scalar multiplication by N-1 is required: -P shares P's x.
    return (bytes(P[1:]), P[0] == 3)
def negate_scalar(scalar):
    """ Return -x mod N for a 32 byte big-endian scalar x, as 32 bytes.

    The result is fully reduced modulo N, so x = 0 maps to 0 (rather than
    to the out-of-range value N), and inputs >= N are treated as their
    residue instead of raising OverflowError. No further validation is
    done, as would be required in a serious implementation.
    """
    value = int.from_bytes(scalar, byteorder="big")
    return ((-value) % N).to_bytes(32, byteorder="big")
def schnorr_create_priv():
    """ Generate a random 32 byte private key suitable for BIP340 signing.

    A random 32 byte string is only a valid secp256k1 private key if it
    lies in [1, N-1]. Values outside that range occur with probability
    roughly 2^-128, but they are rejected and redrawn rather than silently
    used. The accepted key is then negated if necessary so that its
    public key has an even y-coordinate.
    """
    while True:
        candidate = os.urandom(32)
        if 0 < int.from_bytes(candidate, byteorder="big") < N:
            return flip_priv_if_not_even_y(candidate)[0]
### Translation routines between Joinmarket's bitcoin backend,
### which uses the legacy (33 byte, compressed) key types, and the
### new BIP340-style (32 byte, x-only) key types:
def schnorr_priv_to_pub(priv):
    """ Derive the 32 byte BIP340-style (x-only) public key of a 32 byte
    private key.

    The backend's compressed key starts with a byte that only encodes
    the parity of y. BIP340 keys implicitly have even y, so that byte is
    simply dropped. Signing code that needs the private key to match this
    x-only key must use the even-y scalar (see `flip_priv_if_not_even_y`).
    """
    # The backend only returns a *compressed* pubkey when the private key
    # carries a trailing 0x01 byte, so append it.
    compressed = privkey_to_pubkey(priv + b"\x01")
    return compressed[1:]
def schnorr_tweak_mult(scalar, pubkey):
    """ Multiply the point of a 32 byte x-only pubkey by a 32 byte scalar.

    The x-only key is lifted to its even-y point (prefix 0x02). The
    product is returned in the backend's serialization, i.e. *with* the
    parity prefix byte, since the parity of the result is not something
    we control here.
    """
    even_y_point = b"\x02" + pubkey
    return multiply(scalar, even_y_point)
def schnorr_add_pubkeys(pubkeys):
    """ Add an iterable of 32 byte x-only pubkeys together as curve points.

    Each key is lifted to its even-y point (prefix 0x02) before the
    addition. The sum is returned *with* the parity prefix byte, because
    the parity of the resulting point is whatever the addition produces.
    """
    lifted = [b"\x02" + key for key in pubkeys]
    return add_pubkeys(lifted)
def schnorr_challenge(R, pk, message):
    """ Compute the Schnorr challenge hash e = SHA256(R || P || m).

    Args:
        R: Schnorr-serialized (x-only) nonce point.
        pk: Schnorr-serialized (x-only) public key.
        message: the message, as a byte string.
    Returns:
        The 32 byte SHA256 digest.

    Note: BIP340 uses a *tagged* hash here, i.e.
    SHA256(SHA256(tag) || SHA256(tag) || R || P || m) with the tag
    "BIP0340/challenge", see
    https://github.com/bitcoin-core/secp256k1/blob/1e5d50fa93d71d751b95eec6a80f6732879a0071/src/modules/schnorrsig/main_impl.h#L96-L98
    This toy uses an untagged hash, which is one reason its signatures
    do not verify as BIP340-valid; it could easily be changed if required.
    """
    hasher = hashlib.sha256()
    for part in (R, pk, message):
        hasher.update(part)
    return hasher.digest()
def schnorr_sign(priv, message, k=None, R=None, P=None):
    """ Create a Schnorr signature (R, s) with s = k + e * x mod N, where
    e = schnorr_challenge(R, P, message).

    Args:
        priv: 32 byte private key x.
        message: the message to be signed, as a byte string.
        k: optional 32 byte nonce; if omitted it is generated here
            *randomly*, not deterministically as in RFC6979 or BIP340.
        R: optional x-only nonce point to commit to in the challenge.
        P: optional x-only public key to commit to in the challenge.

    This is the generic construction needed for non-vanilla Schnorr such
    as MuSig2: there the R and P fed into the challenge are aggregates
    rather than k*G and x*G, so the caller passes them explicitly and is
    responsible for having already negated k and x as needed.

    If R (resp. P) is omitted we are in the vanilla case, R = k*G (resp.
    P = x*G). Since x-only points are verified as their even-y lift (see
    `schnorr_verify`), k (resp. x) is then negated if necessary so that
    k*G (resp. x*G) has even y; without this, about half of all vanilla
    signatures would fail to verify.

    Returns:
        (R, s): R is an x-only point and s a 32 byte binary string.
    """
    if P is None:
        # vanilla key: use the private scalar matching the even-y point
        priv = flip_priv_if_not_even_y(priv)[0]
        P = schnorr_priv_to_pub(priv)
    if k is None:
        k = os.urandom(32)
    if R is None:
        # vanilla nonce: likewise R must be the even-y point of k
        k = flip_priv_if_not_even_y(k)[0]
        R = schnorr_priv_to_pub(k)
    e = schnorr_challenge(R, P, message)
    # this is not a safe way to do cryptographic operations;
    # timing sidechannel attacks are only *one* reason. But this
    # is the simplest and easiest to understand for educational
    # purposes: convert k, priv and e to integers ...
    k_int, priv_int, e_int = (int.from_bytes(x, byteorder="big")
                              for x in (k, priv, e))
    assert k_int != 0
    # ... and apply Schnorr algebraically:
    sig_int = (k_int + priv_int * e_int) % N
    s = sig_int.to_bytes(32, byteorder="big")
    # signature is returned as pair (R, s), two 32-byte strings:
    print("Sign produced R: {}, s: {}".format(bintohex(R), bintohex(s)))
    print("against this pubkey: {}".format(bintohex(P)))
    print("For message: {}".format(bintohex(message)))
    return (R, s)
def schnorr_verify(pub, message, sig):
    """ Verify a Schnorr signature (R, s) on `message` against `pub`.

    Notice that this is the GENERIC operation, which knows nothing about
    whether the signature was created from a single key or through
    aggregation. It checks s*G == R + e*P with
    e = schnorr_challenge(R, P, message), where the x-only R and P are
    lifted to their even-y points (prefix 0x02).

    Args:
        pub: 32 byte x-only public key.
        message: the signed message, as a byte string.
        sig: pair (R, s) of 32 byte strings.
    Returns:
        True if and only if the signature verifies. Structurally invalid
        signatures (wrong lengths, s outside [1, N-1]) yield False
        instead of an exception from the curve backend.
    """
    R, s = sig
    if len(pub) != 32 or len(R) != 32 or len(s) != 32:
        return False
    # BIP340 rejects s >= N. s == 0 is rejected too: 0*G is the point at
    # infinity, which cannot be serialized as a pubkey, and an honest
    # signature has s == 0 only with negligible probability.
    if not 0 < int.from_bytes(s, byteorder="big") < N:
        return False
    # note that BIP340 specifies further checks not done here, e.g. that
    # R + e*P is not the point at infinity.
    e = schnorr_challenge(R, pub, message)
    eP = schnorr_tweak_mult(e, pub)
    sG = privkey_to_pubkey(s + b"\x01")
    return sG == add_pubkeys([b"\x02" + R, eP])
def serialize_keyset(keyset):
    """ Concatenate an ordered collection of serialized keys into bytes.

    Trivial, but kept as its own function so that a different way of
    serializing a set of keys can be swapped in at a single place.
    """
    out = bytearray()
    for key in keyset:
        out += key
    return bytes(out)
def get_b_coeff(i, meta_pubkey, nonceset, message, size, as_int=True):
    """ Compute the MuSig2 nonce-modification coefficient b_i.

    b_1 is fixed to 1; for any other i,
    b_i = SHA256(str(i) || meta_pubkey || every sub-nonce of every
    participant, flattened in order || message).

    Args:
        i: index of the sub-nonce slot (1-based as used by callers).
        meta_pubkey: the final/total aggregated key, as used in Schnorr
            signing after negotiation between participants.
        nonceset: per-participant lists of sub-nonces; their ordering
            affects the result, so it must be agreed explicitly.
        message: the message being signed.
        size: number of participants (not used in the computation).
        as_int: return an int if True, else the raw bytes.
    """
    if i != 1:
        # Flatten [[R_1,1, R_1,2, ...], [R_2,1, ...], ...] into one list:
        # it is essential that *every* sub-nonce of every participant is
        # hashed in.
        flat_nonces = []
        for participant_nonces in nonceset:
            flat_nonces.extend(participant_nonces)
        preimage = b"".join([str(i).encode(), meta_pubkey,
                             serialize_keyset(flat_nonces), message])
        b = hashlib.sha256(preimage).digest()
    else:
        b = b"\x01"
    return int.from_bytes(b, byteorder="big") if as_int else b
class BasicMuSig2SigningSession(object):
    """ Class to encapsulate the state of one of the participants in a
    MuSig2 signing session.

    Typical flow (as driven by the `__main__` section):
      1. `generate_key()` and `generate_nonceset()`;
      2. send `get_round1_message()` (pubkey + base nonces) to every
         counterparty, and pass each counterparty's round 1 message to
         `receive_round1_message()`; once everything has arrived, the
         aggregate key P~ and aggregate nonce R~ are computed;
      3. `get_round2_message()` creates, stores and returns our partial
         signature s_i, which is sent to every counterparty; theirs are
         passed to `receive_round2_message()`. Once all partial signatures
         are present, the full signature (R~, s~) is stored in
         `self.full_signature`.
    """
    # defines how many sub-nonces to use per participant;
    # see details in the MuSig2 paper for why 5 or 2
    # are reasonable choices.
    nu = 5

    def __init__(self, name, i, message, size):
        # number of keys in the multisig:
        self.size = size
        # for convenience, a string identifier:
        self.name = name
        # our (0-based) index in the list of keys:
        self.i = i
        # the message to be signed:
        self.message = message
        # our private key and our x-only public key:
        self.priv = None
        self.pub = None
        # the set of all participants' original pubkeys:
        self.keyset = [None] * size
        # as above but for the "meta-keys" a_i * P_i:
        self.meta_keys = [None] * size
        # our meta-privkey is a_i * x_i where i is our index:
        self.meta_privkey = None
        # our private nonce scalars:
        self.base_nonces_k = [None] * self.nu
        # all publicly shared base nonces as a list of lists:
        self.base_nonces = [None] * self.size
        # our private nonce for signing after aggregation:
        self.aggregate_nonce_scalar = None
        # each participant's partial signature s_i = k_i' + e * a_i * x_i:
        self.partial_sigs = [None] * self.size
        # the final aggregated pubkey sigma a_i * P_i = P~ (x-only), and
        # whether it had to be negated to have even y:
        self.full_aggregate_pubkey = None
        self.full_aggregate_pubkey_flipped = False
        # the final aggregated public nonce R~ (x-only), and whether it
        # had to be negated to have even y:
        self.full_aggregate_public_nonce = None
        self.full_aggregate_public_nonce_flipped = False
        # the final aggregate signature (R~, s~):
        self.full_signature = None

    def generate_nonceset(self):
        """ Generate self.nu private base nonces and store the matching
        public nonces in our slot of self.base_nonces. Note that they are
        valid Schnorr keys, i.e. correspond to points with even y.
        """
        self.base_nonces_k = [schnorr_create_priv() for _ in range(self.nu)]
        self.base_nonces[self.i] = [schnorr_priv_to_pub(k)
                                    for k in self.base_nonces_k]

    def generate_key(self):
        """ Generates a Schnorr-valid private/public key pair and records
        the public key in our slot of the keyset. If every counterparty's
        key is already known, the meta-keys and the aggregate key are
        computed straight away.
        """
        assert not self.priv
        self.priv = schnorr_create_priv()
        self.pub = schnorr_priv_to_pub(self.priv)
        self.keyset[self.i] = self.pub
        # counterparties' keys may have arrived before ours was generated:
        if all(self.keyset):
            self.meta_key()

    def meta_key(self):
        """ Sets keys for the MuSig(2) style signing operation: for every
        participant j the coefficient a_j = SHA256(keyset || P_j) and the
        meta-pubkey a_j * P_j; our meta-privkey a_i * x_i; and the
        aggregate key P~ = sum of a_j * P_j, made even-y (whether that
        required a negation is stored in self.full_aggregate_pubkey_flipped).
        Must be called after this user's key and the entire set of pubkeys
        in the signing operation (self.keyset) have been defined.
        """
        serialized_keyset = serialize_keyset(self.keyset)
        for j in range(self.size):
            a = hashlib.sha256(b"".join([
                serialized_keyset, self.keyset[j]])).digest()
            # note: we do *not* attempt to coerce this intermediate
            # key to be even-y; we only need the final aggregate key to
            # be even-y.
            self.meta_keys[j] = schnorr_tweak_mult(a, self.keyset[j])
            if j == self.i:
                a_int = int.from_bytes(a, byteorder="big")
                priv_int = int.from_bytes(self.priv, byteorder="big")
                mpriv_int = (a_int * priv_int) % N
                self.meta_privkey = mpriv_int.to_bytes(32, byteorder="big")
                assert privkey_to_pubkey(
                    self.meta_privkey + b"\x01") == self.meta_keys[self.i]
        aggregate = add_pubkeys(self.meta_keys)
        # as per above comment, force the aggregate to be even-y and record
        # if this was needed:
        (self.full_aggregate_pubkey,
         self.full_aggregate_pubkey_flipped) = flip_pub_if_not_even_y(
            aggregate)

    def set_counterparty_key(self, i, pk):
        """ Fill in the key for participant i in the keyset; when the
        keyset is complete, calculate our meta-key (and the aggregate key).
        """
        assert i != self.i
        assert self.meta_keys[self.i] is None
        assert pk != self.pub
        # various other checks would be appropriate
        # here in a serious implementation.
        self.keyset[i] = pk
        # if this was the last other counterparty, we now have the full set
        # of initial pubkeys, so we can now set our own meta-privkey, the
        # other meta-pubkeys, and the final overall aggregate pubkey:
        if all(self.keyset):
            self.meta_key()

    def get_round1_message(self):
        """ In round 1 each signer must send its pubkey followed by its
        self.nu base nonces (generated here if not done yet).
        """
        if not self.base_nonces[self.i]:
            self.generate_nonceset()
        return b"".join([self.pub] + self.base_nonces[self.i])

    def receive_round1_message(self, msg, counterparty):
        """ For counterparty index counterparty, receive a message msg
        which is serialized as:
        32 bytes key
        self.nu * 32 bytes: base nonces for this participant.
        Once the aggregate key is known and every participant's base
        nonces (including ours) are present, the aggregate nonce is
        calculated.
        """
        assert len(msg) == (self.nu + 1) * 32
        self.set_counterparty_key(counterparty, msg[:32])
        self.set_counterparty_base_nonces(msg[32:], counterparty)
        # our own nonces may not exist yet; in that case aggregation is
        # deferred to get_round2_message:
        if self.full_aggregate_pubkey is not None and all(self.base_nonces):
            self.calculate_aggregate_nonce()

    def set_counterparty_base_nonces(self, msg, counterparty_index):
        """ Receive the set of nonces R_i,j for participant i ==
        counterparty_index, serialized as a set of 32 byte strings.
        """
        expected_len = 32 * self.nu
        self.base_nonces[counterparty_index] = [
            msg[start:start + 32] for start in range(0, expected_len, 32)]

    def calculate_aggregate_nonce(self):
        """ Compute the public aggregate nonce R~ and then our own
        aggregated private nonce (the latter depends on the b coefficients,
        which the former also uses).
        """
        self.calculate_full_aggregate_public_nonce()
        self.calculate_our_aggregate_nonce()

    def calculate_our_aggregate_nonce(self):
        """ This calculates the nonce aggregation for *our* index, so:
        k_1 = k_1,1 + b_2 k_1,2 + b_3 k_1,3 + b_4 k_1,4 + b_5 k_1,5 (mod N)
        (if our index in the keyset is 1).
        """
        k_i = int.from_bytes(self.base_nonces_k[0], byteorder="big")
        for j in range(2, self.nu + 1):
            b_j = get_b_coeff(j, self.full_aggregate_pubkey,
                              self.base_nonces, self.message, self.size)
            k_i = (k_i + b_j * int.from_bytes(self.base_nonces_k[j - 1],
                                              byteorder="big")) % N
        self.aggregate_nonce_scalar = k_i.to_bytes(32, byteorder="big")

    def calculate_full_aggregate_public_nonce(self):
        """ This calculates the full public nonce R, which is required to
        create the correct Schnorr challenge hash for each partial
        signature:
        R = (R_1,1 + .. + R_n,1) + b_2(R_1,2 + .. + R_n,2) + ..
            + b_5(R_1,5 + .. + R_n,5)
        It is stored x-only (made even-y) in
        self.full_aggregate_public_nonce, and whether a negation was needed
        is stored in self.full_aggregate_public_nonce_flipped.
        """
        # index 1 is a special case (see MuSig2 paper for justification b1=1):
        terms = [schnorr_add_pubkeys(
            [nonces[0] for nonces in self.base_nonces])]
        for j in range(2, self.nu + 1):
            column_sum = schnorr_add_pubkeys(
                [nonces[j - 1] for nonces in self.base_nonces])
            b_j = get_b_coeff(j, self.full_aggregate_pubkey,
                              self.base_nonces, self.message, self.size,
                              as_int=False)
            terms.append(multiply(b_j, column_sum))
        # note that none of these terms was controlled to be even-y, hence
        # we need to apply the flip operation here and remember if we did it.
        (self.full_aggregate_public_nonce,
         self.full_aggregate_public_nonce_flipped) = flip_pub_if_not_even_y(
            add_pubkeys(terms))

    def get_round2_message(self):
        """ Once all keys and the full nonce are set, we can Schnorr sign
        and return the result as our partial signature; this 32 byte
        string s_i is what needs to be sent to counterparties in round 2.
        It is also stored in self.partial_sigs[self.i].
        """
        # if the last round 1 message arrived before our own nonces were
        # generated, aggregation was deferred until now:
        if (self.full_aggregate_public_nonce is None
                and self.full_aggregate_pubkey is not None
                and all(self.base_nonces)):
            self.calculate_aggregate_nonce()
        assert self.full_aggregate_public_nonce is not None, \
            "round 1 must be complete before creating a partial signature"
        # Note a subtlety: in partial sigs, the meaning of 'R' is not
        # quite as normal; so returning partial sigs means returning
        # only the 's' value.
        # Also note that we flip to using negatives of 'k' and 'x' depending
        # on whether it was already discovered that the *aggregate* R and P
        # values were or were not even. This way, if these keys' signs had
        # to be flipped, *every* participant will flip the corresponding
        # scalar values to make it match after addition.
        priv = (negate_scalar(self.meta_privkey)
                if self.full_aggregate_pubkey_flipped
                else self.meta_privkey)
        k = (negate_scalar(self.aggregate_nonce_scalar)
             if self.full_aggregate_public_nonce_flipped
             else self.aggregate_nonce_scalar)
        # Since neither R = kG nor P = xG applies here, we must specify
        # those keys explicitly to be passed into the hash challenge:
        _, s = schnorr_sign(priv, self.message,
                            k=k,
                            R=self.full_aggregate_public_nonce,
                            P=self.full_aggregate_pubkey)
        self.partial_sigs[self.i] = s
        # counterparties' partial sigs may all have arrived already:
        self._combine_partial_sigs()
        return s

    def receive_round2_message(self, msg, counterparty):
        """ Receive s_n values from counterparties. These will be 32 byte
        strings (same note as above re: 'R' in (R, s)). Once all have
        arrived the full signature is constructed automatically.
        """
        assert len(msg) == 32
        # a thorough implementation would make checks at this point:
        self.partial_sigs[counterparty] = msg
        self._combine_partial_sigs()

    def _combine_partial_sigs(self):
        """ If every partial signature is present, sum them mod N into s~,
        print the full signature as a 64 byte hex string R~::s~ and store
        (R~, s~) in self.full_signature; otherwise do nothing.
        """
        if not all(self.partial_sigs):
            return
        full_sig = sum(int.from_bytes(ps, byteorder="big")
                       for ps in self.partial_sigs) % N
        full_sig_serialized = full_sig.to_bytes(32, byteorder="big")
        print("We have completed the multisignature.")
        print("It is: ")
        # show as a single 64 byte string R::s
        print("".join([bintohex(self.full_aggregate_public_nonce),
                       bintohex(full_sig_serialized)]))
        self.full_signature = (self.full_aggregate_public_nonce,
                               full_sig_serialized)
if __name__ == "__main__":
    # Usage: python musig2-demo.py <number of participants> "<message>"
    usage = ("Usage: python {} <number of participants, at least 2> "
             "<message>".format(sys.argv[0]))
    if len(sys.argv) != 3:
        sys.exit(usage)
    # define the size of the multisig group, taken as first command line
    # argument to the script; as the module docstring notes, a single
    # participant does not work with this demo:
    try:
        size = int(sys.argv[1])
    except ValueError:
        sys.exit(usage)
    if size < 2:
        sys.exit(usage)
    # define the message to be signed; taken as second command line
    # argument to the script. It is encoded as a byte string:
    msg = sys.argv[2].encode("utf-8")
    # Create that many signing instances:
    participants = [
        BasicMuSig2SigningSession("Participant" + str(i), i, msg, size)
        for i in range(size)]
    # Having created the participants, a brief sketch of what must happen:
    # 1. Each party sends their pubkey for signing (presumably freshly
    #    generated).
    # 2. They also send in the same (first) message a list of 5 (or 2)
    #    nonces, which are also really pubkeys/curve points.
    # 3. When everyone receives everyone else's first message, they can all
    #    calculate the aggregate public key (P~) and the aggregate public
    #    nonce (R~). Using those they can calculate a partial signature for
    #    themselves (using their private key as well, of course), called
    #    s_i for the i-th participant.
    # 4. They all send s_i to each other.
    # 5. Everyone can now calculate the full signature as
    #    s~ = s_1 + s_2 + ... and combine it with R~ to get a
    #    normal-looking Schnorr signature (R~, s~) against the public
    #    key P~.
    # 6. Lastly everyone (including the Bitcoin network, say) can verify
    #    that (R~, s~) is valid against P~ and so e.g. a transaction will
    #    be counted valid.

    # Each participant should generate its key and its set of base nonces.
    # By default we have 5 each; according to the ROM proof in the MuSig2
    # paper this is appropriate, however using 2 instead is also considered
    # provably secure (AGM+ROM) (see paper for the gory details!).
    for participant in participants:
        participant.generate_key()
        participant.generate_nonceset()
    # Now each party must broadcast to the other parties their base nonce
    # set, along with their proposed signing key. Each round 1 message is
    # built once and then delivered to every other participant.
    round1_messages = [p.get_round1_message() for p in participants]
    for receiver in participants:
        for sender in participants:
            if sender is not receiver:
                receiver.receive_round1_message(
                    round1_messages[sender.i], sender.i)
    # When all of those messages are exchanged, all participants have all
    # base nonces and all base pubkeys, and will have also calculated the
    # full aggregate nonce. Now each creates its partial signature and
    # sends it to everyone else:
    for participant in participants:
        participant.get_round2_message()
    for receiver in participants:
        for sender in participants:
            if sender is not receiver:
                receiver.receive_round2_message(
                    sender.partial_sigs[sender.i], sender.i)
    # At some point in the above loop, each participant will fill their
    # list of partial signatures, and automatically print out the
    # total/final Schnorr multisignature (R, s).
    # Now let's check: are the signatures and keys created by all
    # participants the same? And most importantly, do they verify as a
    # valid vanilla Schnorr signature?
    full_sigs = [p.full_signature for p in participants]
    assert full_sigs[0] is not None
    assert full_sigs.count(full_sigs[0]) == size
    agg_pubkeys = [p.full_aggregate_pubkey for p in participants]
    assert agg_pubkeys.count(agg_pubkeys[0]) == size
    assert schnorr_verify(participants[0].full_aggregate_pubkey, msg,
                          participants[0].full_signature)
    print("Success! The created signature verifies as a normal Schnorr signature.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment