Skip to content

Instantly share code, notes, and snippets.

@AdamISZ
Last active April 15, 2025 17:24
Show Gist options
  • Save AdamISZ/1012d90d47c53a9d4f99995877548b4f to your computer and use it in GitHub Desktop.
Save AdamISZ/1012d90d47c53a9d4f99995877548b4f to your computer and use it in GitHub Desktop.
Back's LSAG variant in Python
# This module illustrates how Back's variant
# of the LSAG (linkable spontaneous anonymous group)
# signature of Liu, Wei and Wong (2004) works.
# There is no command line tool, only a set of tests
# as sanity checks that the algorithm is correct.
#
# To use, pip install python-bitcointx (which comes from:
# https://github.com/Simplexum/python-bitcointx )
# and note that installation of this will only work if it
# succeeds in finding libsecp256k1 on your system.
# The actual algorithm lives only in the sign() and verify() functions,
# and the test functions at the end show how to use it.
# Most of the module is just scaffolding that sets up all the ECC
# operations "nicely" (albeit with holes; it is only for testing).
# The one nice thing here is that the ECC operations are
# decently fast.
from typing import List
import binascii
import unittest
import struct
import hashlib
import os
from bitcointx.core.key import CPubKey, CKey
from bitcointx.core.secp256k1 import get_secp256k1
# python-bitcointx does not provide multiplication of a public key by a
# scalar, so the ctypes signature of libsecp256k1's
# secp256k1_ec_pubkey_tweak_mul is declared here for multiply() to call.
import ctypes
secp_obj = get_secp256k1()
_tweak_mul_fn = secp_obj.lib.secp256k1_ec_pubkey_tweak_mul
_tweak_mul_fn.restype = ctypes.c_int
_tweak_mul_fn.argtypes = [ctypes.c_void_p, ctypes.c_char_p,
                          ctypes.c_char_p]
# Order n of the secp256k1 group; Scalar values are reduced modulo this.
groupN = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
# Sentinel used by pointadd()/pointmult() to stand for the point at infinity.
infty = "INFTY"
def bintohex(b):
    """Return the lowercase hexadecimal encoding of bytes-like *b* as str."""
    hex_bytes = binascii.b2a_hex(b)
    return hex_bytes.decode('ascii')
class Scalar(object):
    """An integer modulo groupN (the secp256k1 group order).

    Addition, subtraction and multiplication with another Scalar or a plain
    int (on either side) return a new Scalar reduced modulo groupN. Any other
    operand type makes the operator return NotImplemented, so Python raises
    TypeError instead of an obscure internal error.

    Equality holds only between Scalars with the same value; a Scalar never
    compares equal to an int (callers such as pointmult rely on this).
    """

    def __init__(self, x):
        # self.x: the canonical representative, always in [0, groupN).
        self.x = x % groupN

    @staticmethod
    def _coerce(other):
        """Return the integer value of an int or Scalar operand, else None."""
        if isinstance(other, Scalar):
            return other.x
        if isinstance(other, int):
            return other
        return None

    def to_bytes(self):
        """Serialize as a 32-byte big-endian integer."""
        return int.to_bytes(self.x, 32, byteorder="big")

    @classmethod
    def from_bytes(cls, b):
        """Parse a big-endian byte string, reducing the value modulo groupN."""
        return cls(int.from_bytes(b, byteorder="big"))

    @classmethod
    def pow(cls, base, exponent):
        """Return base**exponent mod groupN as a Scalar.

        base may be an int or a Scalar; exponent is an int.
        """
        if isinstance(base, Scalar):
            base = base.x
        # Inside a method body the class attribute does not shadow the
        # builtin, so this is the 3-argument builtin pow().
        return cls(pow(base, exponent, groupN))

    def __add__(self, other):
        y = self._coerce(other)
        if y is None:
            return NotImplemented
        return Scalar(self.x + y)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        y = self._coerce(other)
        if y is None:
            return NotImplemented
        return Scalar(self.x - y)

    def __rsub__(self, other):
        y = self._coerce(other)
        if y is None:
            return NotImplemented
        return Scalar(y - self.x)

    def __mul__(self, other):
        # Previously only the ints 0 and 1 were accepted; any int now works.
        y = self._coerce(other)
        if y is None:
            return NotImplemented
        return Scalar(self.x * y)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __neg__(self):
        return Scalar(-self.x)

    def __str__(self):
        return str(self.x)

    def __repr__(self):
        return str(self.x)

    def __eq__(self, other):
        if isinstance(other, Scalar):
            return self.x == other.x
        return False

    def __hash__(self):
        # Defining __eq__ alone would set __hash__ to None; equal Scalars
        # have equal x, so hashing x keeps the hash/eq contract.
        return hash(self.x)
def getG(compressed: bool = True) -> CPubKey:
    """Return the secp256k1 generator G as a CPubKey.

    G is derived as the public key of the private key 1; *compressed*
    selects the serialization used. CPubKey is a bytes subclass, so the
    result can be used directly as the binary representation.
    """
    one = (1).to_bytes(32, byteorder="big")
    return CKey(one, compressed=compressed).pub
def add_pubkeys(pubkeys: List[bytes]) -> CPubKey:
    """Return the elliptic-curve sum of binary compressed public keys.

    Every entry is wrapped in CPubKey. A ValueError is raised if any key is
    not compressed, or (checked second) not fully valid. The sum is the
    result of CPubKey.combine over all keys.
    """
    points = list(map(CPubKey, pubkeys))
    if not all(p.is_compressed() for p in points):
        raise ValueError("Only compressed pubkeys can be added.")
    if not all(p.is_fullyvalid() for p in points):
        raise ValueError("Invalid pubkey format.")
    return CPubKey.combine(*points)
def multiply(s: bytes, pub: bytes, return_serialized: bool = True) -> bytes:
    '''Return the elliptic-curve point s*P.

    s is a 32-byte big-endian scalar and pub a binary serialized public
    key P (33 bytes when compressed).

    Validation happens here, before libsecp256k1 is called: s is checked by
    constructing a CKey from it, and P by CPubKey.is_fullyvalid(). Either
    failure raises ValueError, as does a failed multiplication.

    Returns the product as plain bytes when return_serialized is True (the
    default), otherwise as a CPubKey object.
    '''
    try:
        CKey(s)
    except ValueError as exc:
        raise ValueError("Invalid tweak for libsecp256k1 "
                         "multiply: {}".format(bintohex(s))) from exc
    pub_obj = CPubKey(pub)
    if not pub_obj.is_fullyvalid():
        raise ValueError("Invalid pubkey for multiply: {}".format(
            bintohex(pub)))
    privkey_arg = ctypes.c_char_p(s)
    # libsecp256k1 writes s*P back into this buffer, which is why the
    # result is read from pubkey_buf after the call.
    pubkey_buf = pub_obj._to_ctypes_char_array()
    ret = secp_obj.lib.secp256k1_ec_pubkey_tweak_mul(
        secp_obj.ctx.verify, pubkey_buf, privkey_arg)
    if ret != 1:
        assert ret == 0
        raise ValueError('Multiplication failed')
    product = CPubKey._from_ctypes_char_array(pubkey_buf)
    return bytes(product) if return_serialized else product
def pointadd(points):
    """Return the sum of *points*, treating infty as the identity.

    Entries equal to infty are dropped. Then:
    - no entries left: returns infty;
    - one entry left: returns it unchanged;
    - otherwise: returns add_pubkeys() of the remaining entries.

    NB this is not fully correct: it does not account for cancellation, i.e.
    a sum that is itself the point at infinity (it is unclear how such a
    result would be serialized). Per the original author, that case does not
    happen in this module's uses.
    """
    finite = [p for p in points if p != infty]
    if not finite:
        return infty
    if len(finite) == 1:
        return finite[0]
    return add_pubkeys(finite)
def pointmult(multiplier, point):
    """Return multiplier*point.

    multiplier is normally a Scalar; a plain int is also accepted and is
    reduced modulo groupN. point is a serialized public key (bytes or
    CPubKey) or infty.

    Returns:
    - infty when the product is the point at infinity (zero multiplier, or
      point == infty);
    - point itself, unchanged, for the int multiplier 1;
    - otherwise a CPubKey computed by multiply().
    """
    if isinstance(multiplier, int):
        if multiplier == 1:
            # Historical shortcut: int 1 returns the input object as-is.
            return point
        # Previously any int other than 0/1 crashed on .to_bytes().
        multiplier = Scalar(multiplier)
    # k*infinity is infinity; passing infty to multiply() would fail.
    if point == infty or multiplier == Scalar(0):
        return infty
    return multiply(multiplier.to_bytes(), point, return_serialized=False)
def hash_iterable(elements) -> Scalar:
    """Hash bytes(e) for each e in *elements*, concatenated in order.

    The elements are joined with no separators or length prefixes. The
    SHA-256 digest is read as a big-endian integer and returned as a Scalar.
    """
    hasher = hashlib.sha256()
    for element in elements:
        hasher.update(bytes(element))
    return Scalar.from_bytes(hasher.digest())
def generate_random_scalar():
    """Return a Scalar made from 32 bytes of os.urandom, reduced mod groupN.

    The modular reduction bias is negligible because groupN is within about
    2**128 of 2**256.
    """
    return Scalar(int.from_bytes(os.urandom(32), byteorder="big"))
def generate_keyset(n):
    """Create n random key pairs.

    Returns (privatekeys, keyset): a list of n random Scalars and the list of
    their public keys privatekeys[i]*G, in the same order.
    """
    generator = getG()
    privs = [generate_random_scalar() for _ in range(n)]
    pubs = [pointmult(priv, generator) for priv in privs]
    return privs, pubs
def hashtopoint(point):
    """Deterministically hash *point* to another curve point (a CPubKey).

    This is the key-image base H(P) used by sign() and verify(). For
    counter = 0..255 it computes
        SHA-256(bytes(point) || 0x00 || counter)
    and uses the digest as the x-coordinate of a compressed point with a
    0x02 prefix. The first candidate that is a valid public key is returned.

    Raises RuntimeError if all 256 candidates are invalid, which is not
    expected to happen in practice.
    """
    index = 0
    # The index byte is always 0, but it stays in the preimage so that the
    # mapping (and therefore key images) is unchanged.
    seed = bytes(point) + struct.pack(b'B', index)
    for counter in range(256):
        seed_c = seed + struct.pack(b'B', counter)
        hashed_seed = hashlib.sha256(seed_c).digest()
        # Every x-coord on the curve has two y-values, encoded in compressed
        # form with an 02/03 parity byte. We always choose 02.
        claimed_point = b"\x02" + hashed_seed
        try:
            nums_point = CPubKey(claimed_point)
        except ValueError:
            continue
        # CPubKey does not reject invalid data on construction, so validity
        # is checked explicitly; an assert would be stripped under -O and
        # let an invalid point through.
        if nums_point.is_fullyvalid():
            return nums_point
    raise RuntimeError("It seems inconceivable, doesn't it?")
def verify(sigma, message, keyset, keyimage):
    """Verify a Back-variant LSAG signature.

    sigma is [s_0, ..., s_{n-1}, e_0] (Scalars) for a ring *keyset* of n
    public keys P_i, and keyimage is the signer's key image I. Starting from
    the claimed e_0, each ring member i computes
        R1 = s_i*G - e_i*P_i
        R2 = s_i*H(P_i) - e_i*I
        e_{i+1} = hash_iterable([message, R1, R2])
    where H is hashtopoint(). The signature is valid iff e_n equals the
    claimed e_0.

    Returns (is_valid, keyimage); keyimage is returned for the caller's
    linking checks.

    Raises ValueError if keyset is empty or sigma does not have exactly
    len(keyset) + 1 entries.
    """
    n = len(keyset)
    if n == 0:
        # With no ring members the loop is skipped and any e_0 would
        # trivially "verify"; reject that forgery outright.
        raise ValueError("keyset must contain at least one public key")
    # Check the shape before indexing into sigma.
    if len(sigma) != n + 1:
        raise ValueError(
            "sigma must have len(keyset) + 1 = {} entries, got {}".format(
                n + 1, len(sigma)))
    G = getG()
    svals = sigma[:n]
    e0claim = sigma[-1]
    # Start at index 0 with the claimed e0; only at the end check that the
    # same e0 is reproduced.
    current_e = e0claim
    for pubkey, s_i in zip(keyset, svals):
        R1 = pointadd([pointmult(s_i, G), pointmult(-current_e, pubkey)])
        HP = hashtopoint(pubkey)
        R2 = pointadd([pointmult(s_i, HP), pointmult(-current_e, keyimage)])
        current_e = hash_iterable([message, R1, R2])
    return current_e == e0claim, keyimage
def sign(privatekey, index, message, otherkeyset):
    """Produce a Back-variant LSAG signature over *message*.

    Args:
        privatekey: the signer's secret Scalar x.
        index: position, 0 <= index <= len(otherkeyset), at which the
            signer's public key x*G is inserted into otherkeyset to form
            the ring.
        message: bytes to sign.
        otherkeyset: list of the other ring members' public keys.

    Returns:
        (sigma, message, keyset, keyimage), where:
        - sigma is [s_0, ..., s_{n-1}, e_0];
        - keyset is the full ring;
        - keyimage is x*hashtopoint(x*G).

    Raises:
        ValueError: if index is outside [0, len(otherkeyset)].
    """
    # An out-of-range or negative index would otherwise misplace our key in
    # the ring (slicing silently clamps) or fail only at the very end.
    if not 0 <= index <= len(otherkeyset):
        raise ValueError(
            "index must be in [0, {}], got {}".format(len(otherkeyset), index))
    G = getG()
    ourkey = pointmult(privatekey, G)
    keyimagebase = hashtopoint(ourkey)
    keyimage = pointmult(privatekey, keyimagebase)
    # Add our key to the keyset.
    keyset = otherkeyset[:index] + [ourkey] + otherkeyset[index:]
    n = len(keyset)
    # Random nonce k for our own ring position.
    k_i = generate_random_scalar()
    R_i = pointmult(k_i, G)
    R_iprime = pointmult(k_i, keyimagebase)
    e = [None] * n
    s = [None] * n
    e[(index + 1) % n] = hash_iterable([message, R_i, R_iprime])
    # Walk the rest of the ring, choosing random (faked) s-values for the
    # other members and chaining each challenge into the next.
    for i in range(index + 1, index + n):
        idx = i % n
        s[idx] = generate_random_scalar()
        R = pointadd([pointmult(s[idx], G), pointmult(-e[idx], keyset[idx])])
        Rprime = pointadd([pointmult(s[idx], hashtopoint(keyset[idx])),
                           pointmult(-e[idx], keyimage)])
        e[(idx + 1) % n] = hash_iterable([message, R, Rprime])
    # Close the ring with the non-faked s-value: s = k + e*x makes
    # s*G - e*(x*G) = k*G, reproducing the commitment hashed above.
    s[index] = k_i + e[index] * privatekey
    sigma = s + [e[0]]
    return (sigma, message, keyset, keyimage)
class TestRingSignature(unittest.TestCase):
    """Sanity checks for sign() and verify() on rings of five keys."""

    def _sign_at(self, privkeys, keyset, position, message):
        """Sign with privkeys[position], inserting our key back at *position*
        in the ring built from the remaining keys of *keyset*."""
        others = keyset[:position] + keyset[position + 1:]
        return sign(privkeys[position], position, message, others)

    def test_signature_verification(self):
        """Basic test: sign and verify a message with a valid signature."""
        message = os.urandom(32)
        privkeys, keyset = generate_keyset(5)  # ring size 5
        sigma, _, ring, keyimage = self._sign_at(privkeys, keyset, 2, message)
        is_valid, _ = verify(sigma, message, ring, keyimage)
        self.assertTrue(is_valid)

    def test_invalid_signature_message_tampered(self):
        """Signature should fail if the message is modified."""
        message = os.urandom(32)
        privkeys, keyset = generate_keyset(5)
        sigma, _, ring, keyimage = self._sign_at(privkeys, keyset, 2, message)
        tampered_message = b"modified" + message[8:]
        is_valid, _ = verify(sigma, tampered_message, ring, keyimage)
        self.assertFalse(is_valid)

    def test_invalid_signature_sigma_tampered(self):
        """Signature should fail if one of the s-values is changed."""
        message = os.urandom(32)
        privkeys, keyset = generate_keyset(5)
        sigma, _, ring, keyimage = self._sign_at(privkeys, keyset, 2, message)
        sigma[0] += 1  # perturb the first s-value
        is_valid, _ = verify(sigma, message, ring, keyimage)
        self.assertFalse(is_valid)

    def test_invalid_signature_keyset_modified(self):
        """Signature should fail if the keyset is modified."""
        message = os.urandom(32)
        privkeys, keyset = generate_keyset(5)
        sigma, _, ring, keyimage = self._sign_at(privkeys, keyset, 2, message)
        altered_ring = list(ring)
        # Replace the first ring member with a random multiple of itself.
        altered_ring[0] = pointmult(generate_random_scalar(), ring[0])
        is_valid, _ = verify(sigma, message, altered_ring, keyimage)
        self.assertFalse(is_valid)

    def test_keyimage_reuse_detected(self):
        """A reused private key should yield the same key image, allowing linkage."""
        message1 = os.urandom(32)
        message2 = os.urandom(32)
        message3 = os.urandom(32)
        privkeys, keyset = generate_keyset(5)
        tags = []
        # Two signatures by key 2, then one by key 1.
        for position, message in ((2, message1), (2, message2), (1, message3)):
            sigma, _, ring, keyimage = self._sign_at(
                privkeys, keyset, position, message)
            is_valid, tag = verify(sigma, message, ring, keyimage)
            self.assertTrue(is_valid)
            tags.append(tag)
        tag1, tag2, tag3 = tags
        # Same key -> same linking tag; different keys -> different tags.
        self.assertEqual(tag1, tag2)
        self.assertNotEqual(tag2, tag3)
# Running this file directly executes the unittest sanity checks defined above.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment