Skip to content

Instantly share code, notes, and snippets.

@smarter
Created November 6, 2024 13:06
Show Gist options
  • Save smarter/aa541c32ca359f6b2cfac50815f2f34f to your computer and use it in GitHub Desktop.
Save smarter/aa541c32ca359f6b2cfac50815f2f34f to your computer and use it in GitHub Desktop.
#!/usr/bin/python3
from z3 import *
def is_valid_n_bits(n, v):
    """Return a z3 constraint stating that *v* lies in the signed n-bit range.

    The range is ``[-2**(n-1), 2**(n-1) - 1]``. When *v* is a z3 bit-vector,
    ``>=``/``<=`` are signed comparisons and the integer bounds are converted to
    constants of *v*'s width.

    Raises:
        ValueError: if ``n < 1``, or if *v* is a bit-vector narrower than ``n``
            bits. In that case the bounds would be silently reduced modulo
            ``2**width`` and the constraint would no longer express the range.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if is_bv(v) and n > v.size():
        raise ValueError(f"n={n} exceeds the {v.size()}-bit width of {v}")
    lo = -(1 << (n - 1))
    hi = (1 << (n - 1)) - 1
    return And(v >= lo, v <= hi)
def are_valid_n_bits(n, *values):
    """Return a z3 conjunction requiring every value to fit in the signed n-bit range.

    Each element of *values* is checked with ``is_valid_n_bits(n, value)``.
    """
    per_value = [is_valid_n_bits(n, value) for value in values]
    # z3's And accepts a single list of operands as well as varargs.
    return And(per_value)
def is_min_max_n_bits(n, v):
    """Return a z3 constraint that *v* equals the minimum or maximum signed n-bit value.

    That is, ``v == -2**(n-1)`` or ``v == 2**(n-1) - 1``; for ``n == 1`` the two
    allowed values are -1 and 0.

    Raises:
        ValueError: if ``n < 1``, or if *v* is a bit-vector narrower than ``n``
            bits. In that case the two constants would be silently reduced
            modulo ``2**width`` and would no longer be the intended extremes.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if is_bv(v) and n > v.size():
        raise ValueError(f"n={n} exceeds the {v.size()}-bit width of {v}")
    lo = -(1 << (n - 1))
    hi = (1 << (n - 1)) - 1
    return Or(v == lo, v == hi)
def are_min_max_n_bits(n, *values):
    """Return a z3 conjunction requiring every value to be the signed n-bit min or max.

    Each element of *values* is checked with ``is_min_max_n_bits(n, value)``.
    """
    per_value = [is_min_max_n_bits(n, value) for value in values]
    # z3's And accepts a single list of operands as well as varargs.
    return And(per_value)
def wht_4(solver, x00, x01, x10, x11):
    """Integer 4-point Walsh-Hadamard transform built from lifting steps.

    Let ``half = ((x10 + x11) - (x00 - x01)) >> 1``; this is the only rounding
    step (``>>`` floors for both Python ints and z3 bit-vectors). The result is
    ``(y00, y01, y10, y11)`` where, up to that rounding:

        y00 = ( x00 + x01 + x10 + x11) / 2
        y01 = ( x00 - x01 + x10 - x11) / 2
        y10 = ( x00 + x01 - x10 - x11) / 2
        y11 = ( x00 - x01 - x10 + x11) / 2

    ``solver`` is accepted but not used by this function.
    """
    diff_top = x00 - x01
    sum_bottom = x10 + x11
    half = (sum_bottom - diff_top) >> 1
    out00 = x00 + half
    out11 = x11 - half
    # The remaining outputs are derived from out00/out11 so they reuse the
    # single rounded value instead of rounding again.
    return (out00, diff_top - out11, out00 - sum_bottom, out11)
def check_transformation(input_bits=4):
    """Search for inputs that push a three-stage integer WHT out of range.

    Two independent 4x4 grids of symbolic ``input_bits``-wide bit-vectors are
    created with prefixes ``x`` and ``xx``. Names are ``<prefix><pos><blk>``,
    with ``pos`` in 00/01/10/11 and ``blk`` in a..d. Every input is constrained
    to the minimum or maximum of the signed ``input_bits - 3``-bit range.

    The three transforms are:

    * Transform 1 applies ``wht_4`` inside each block.
    * Transform 2 applies it across blocks to same-position coefficients.
    * Transform 3 does the same, but takes blocks a and c from the ``x`` chain
      and blocks b and d from the ``xx`` chain.

    The solver then looks for an assignment where at least one stage-3 output
    does not fit in the signed ``input_bits - 1``-bit range. All arithmetic is
    ``input_bits``-wide bit-vector arithmetic, so intermediate values wrap.

    Args:
        input_bits: Width of every bit-vector (default 4). Must be >= 4 so that
            the ``input_bits - 3`` input range is non-empty.

    Prints one of:

    * the counterexample (inputs and every stage), if the query is sat;
    * ``No counterexample found``, if the query is unsat;
    * the solver's reason, if the result is unknown.

    Raises:
        ValueError: if ``input_bits`` is less than 4.
    """
    if input_bits < 4:
        raise ValueError(f"input_bits must be >= 4, got {input_bits}")

    # Solver tuning knobs that were tried during experimentation:
    # set_param('parallel.enable', True)
    # set_param('verbose', 10)
    # set_param('sat.smt', True)
    # set_param('smt.bv.solver', 1)  # polysat
    # solver = SolverFor("BV")
    solver = Solver()

    positions = ("00", "01", "10", "11")  # coefficient index within a block
    blocks = "abcd"                       # block index
    # Each stage is a dict {(pos, blk): expr}. `order` is block-major
    # (x00a, x01a, x10a, x11a, x00b, ...). Variables are declared, constrained
    # and printed in this order.
    order = [(pos, blk) for blk in blocks for pos in positions]

    def make_inputs(prefix):
        # Fresh inputs, each pinned to min/max of the signed input_bits-3 range.
        # (A looser variant that was tried: are_valid_n_bits(input_bits - 3, ...).)
        grid = {(pos, blk): BitVec(f"{prefix}{pos}{blk}", input_bits) for pos, blk in order}
        solver.add(are_min_max_n_bits(input_bits - 3, *(grid[key] for key in order)))
        return grid

    def within_blocks(stage):
        # wht_4 over the four coefficients of each block.
        out = {}
        for blk in blocks:
            outputs = wht_4(solver, *(stage[pos, blk] for pos in positions))
            for pos, expr in zip(positions, outputs):
                out[pos, blk] = expr
        return out

    def across_blocks(sources):
        # wht_4 over coefficient `pos_in` of blocks a..d; block k is read from
        # sources[k]. The k-th position's four outputs form the k-th block.
        out = {}
        for pos_in, blk_out in zip(positions, blocks):
            outputs = wht_4(solver, *(src[pos_in, blk] for blk, src in zip(blocks, sources)))
            for pos, expr in zip(positions, outputs):
                out[pos, blk_out] = expr
        return out

    x = make_inputs("x")
    xx = make_inputs("xx")

    ## Transform 1
    y = within_blocks(x)
    yy = within_blocks(xx)
    ## Transform 2
    z = across_blocks((y, y, y, y))
    zz = across_blocks((yy, yy, yy, yy))
    # Variant that was tried: check stage 2 instead of stage 3 with
    # solver.add(Not(are_valid_n_bits(input_bits - 1, *(z[k] for k in order))))
    ## Transform 3: blocks a/c come from the x chain, b/d from the xx chain.
    aa = across_blocks((z, zz, z, zz))

    # Require at least one stage-3 output outside the signed (input_bits-1)-bit range.
    solver.add(Not(are_valid_n_bits(input_bits - 1, *(aa[key] for key in order))))

    # Check satisfiability
    result = solver.check()
    if result == unsat:
        print("No counterexample found")
        return
    if result != sat:
        # `unknown` (e.g. a resource limit) proves nothing; don't report it as unsat.
        print(f"# Solver returned {result}: {solver.reason_unknown()}")
        return

    model = solver.model()

    def show(prefix, stage):
        for pos, blk in order:
            # model_completion=True ensures a concrete numeral is returned even
            # if the model leaves some variable unassigned.
            value = model.eval(stage[pos, blk], model_completion=True)
            print(f"{prefix}{pos}{blk} = {value.as_signed_long()}")

    print("# Found counterexample:")
    show("x", x)
    show("xx", xx)
    # Evaluate outputs
    print("\n# Transform 1:")
    show("y", y)
    show("yy", yy)
    print("\n# Transform 2:")
    show("z", z)
    show("zz", zz)
    print("\n# Transform 3:")
    show("AA", aa)
if __name__ == "__main__":
    # Script entry point: run the counterexample search with default settings.
    check_transformation()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment