Skip to content

Instantly share code, notes, and snippets.

@user202729
Created February 7, 2025 09:51
Show Gist options
  • Save user202729/d95cfcbb50643099a04609f3b7ee9e38 to your computer and use it in GitHub Desktop.
Save user202729/d95cfcbb50643099a04609f3b7ee9e38 to your computer and use it in GitHub Desktop.
Solve z3 problems with pysat
#!/usr/bin/env python3
"""
This script shows the solution of a bit-vector constraint problem using
two different methods.
• solve_with_z3(constraints, variables) uses Z3 directly.
• solve_with_pysat(constraints, variables) first applies bit-blasting and Tseitin CNF
transformation in Z3, then uses PySAT to solve the CNF and reconstructs the model.
A demonstration is provided in main().
See also: https://stackoverflow.com/q/19551225
"""
from z3 import (
BitVec, BitVecVal, Bool, Extract, Goal, Tactic, Then, simplify, sat,
Solver as Z3Solver
)
from pysat.solvers import Solver
# ----- Helper Functions for CNF Traversal and Transformation -----
def extract_clause(expr):
    """
    Split a clause expression into its list of literal Z3 expressions.

    An expression whose top-level declaration is an OR yields its children;
    any other expression is treated as a single literal and is returned as a
    one-element list.
    """
    is_disjunction = expr.decl().kind() == Z3_OP_OR
    return list(expr.children()) if is_disjunction else [expr]
def collect_atoms(clauses):
    """
    Gather the distinct Boolean atoms occurring in a list of clauses.

    Each clause is an iterable of Z3 literals. A literal that is a NOT node
    contributes its single child; any other literal contributes itself.
    Atoms are de-duplicated by declaration name, and the returned list keeps
    the order in which each name was first seen.
    """
    def underlying_atom(literal):
        # Peel one negation layer: NOT(atom) -> atom.
        if literal.decl().kind() == Z3_OP_NOT:
            return literal.children()[0]
        return literal

    atoms_by_name = {}
    for clause in clauses:
        for literal in clause:
            atom = underlying_atom(literal)
            # Keyed by name so repeated occurrences collapse to one entry.
            atoms_by_name[atom.decl().name()] = atom
    return list(atoms_by_name.values())
def lit_to_int(lit, atom_name_to_int):
    """
    Translate a Z3 literal into a signed integer literal for PySAT.

    `atom_name_to_int` maps atom names to positive integers. A literal that is
    a NOT node maps to the negated index of its child atom; any other literal
    maps to the index of its own name. Raises KeyError if the atom name is not
    present in the mapping.
    """
    negated = lit.decl().kind() == Z3_OP_NOT
    atom = lit.children()[0] if negated else lit
    index = atom_name_to_int[atom.decl().name()]
    return -index if negated else index
# We need Z3_OP_OR and Z3_OP_NOT for inspecting the expressions.
from z3 import Z3_OP_OR, Z3_OP_NOT
# Helper function to add bit-link constraints: create Boolean variables for each bit.
def add_bit_link_constraints(g, bv):
    """
    Tie every bit of the bit-vector `bv` to a named Boolean in goal `g`.

    For each bit position i in range(bv.size()), a Boolean constant named
    "<name>[<i>]" (where <name> is bv's declaration name) is created and the
    constraint (Extract(i, i, bv) == 1) <=> that Boolean is added to `g`.
    Nothing is returned; `g` is modified in place.
    """
    base_name = bv.decl().name()
    for idx in range(bv.size()):
        bit_flag = Bool(f"{base_name}[{idx}]")
        # The one-bit slice equals 1 exactly when the linked Boolean is true.
        bit_is_set = Extract(idx, idx, bv) == BitVecVal(1, 1)
        g.add(bit_is_set == bit_flag)
# ----- Solving with Z3 Directly -----
def solve_with_z3(constraints, variables):
    """
    Solve the given high-level constraints with a plain Z3 solver.

    No explicit bit-blasting is requested here; the constraints are handed to
    Z3 as they are.

    Arguments:
        constraints : list of Z3 constraints
        variables   : list of Z3 variables (e.g. BitVecs); accepted for
                      symmetry with solve_with_pysat but not used in the body

    Returns the Z3 model when the check result is `sat`, otherwise None.
    """
    z3_solver = Z3Solver()
    z3_solver.add(constraints)
    outcome = z3_solver.check()
    return z3_solver.model() if outcome == sat else None
# ----- Solving with PySAT -----
def solve_with_pysat(constraints, variables, verbose=False, need_recover=True, solver_name="minisat22"):
    """
    Solve the constraints by bit-blasting them into CNF with Z3 tactics and
    then running a PySAT solver on that CNF.

    Arguments:
        constraints : list of Z3 constraints
        variables   : list of Z3 bit-vector variables for which we want to
                      extract the assignment.
        verbose     : if True, print progress information (atom, clause and
                      variable counts, solving time).
        need_recover: if True, link every bit of every variable to a named
                      Boolean ("<name>[<i>]") so the assignment can be read
                      back from the SAT model. If False, no link constraints
                      are added and the returned values fall back to 0 for
                      every bit that has no linked Boolean, so they should
                      not be relied upon.
        solver_name : the PySAT solver to use (default is "minisat22").

    Returns:
        None if the problem is unsatisfiable; otherwise a dict mapping each
        variable's name (str) to a Z3 BitVecVal of the variable's width.
        (This is a plain mapping, not a Z3 model object.)
    """
    from time import perf_counter

    # Create a goal and add the high-level constraints.
    g = Goal()
    for c in constraints:
        g.add(c)

    # Add bit-link constraints for each bit-vector variable.
    if need_recover:
        for var in variables:
            add_bit_link_constraints(g, var)

    # Apply chain tactics: bit-blast then convert to CNF using Tseitin transformation.
    combined_tactic = Then(Tactic("simplify"), Tactic("bit-blast"), Tactic("simplify"), Tactic("tseitin-cnf"))
    cnf_goal = combined_tactic(g)

    # Extract CNF clauses (lists of Z3 literal expressions).
    # NOTE(review): clauses of all subgoals are merged into a single CNF, as
    # before; this pipeline is expected to yield one subgoal - confirm.
    cnf_clauses_expr = []
    for subgoal in cnf_goal:
        # If the tactics already reduced the goal to False, the only formula
        # left is the constant `False`. Treating it as an ordinary atom would
        # give it a SAT variable and produce a bogus "satisfiable" answer.
        if subgoal.inconsistent():
            return None
        for f in subgoal:
            cnf_clauses_expr.append(extract_clause(f))
    if verbose:
        print("extracted")

    # Number every Boolean atom occurring in the CNF (bit-link Booleans as
    # well as the fresh ones introduced by the tactics), starting at 1.
    # Sorting by name gives a reproducible numbering.
    atom_names = sorted(atom.decl().name() for atom in collect_atoms(cnf_clauses_expr))
    atom_name_to_int = {atom_name: i for i, atom_name in enumerate(atom_names, start=1)}
    int_to_atom_name = {i: atom_name for atom_name, i in atom_name_to_int.items()}
    if verbose:
        print(f"total number of atoms: {len(atom_names)}")

    # Convert each clause from Z3 literals to integer literals for PySAT.
    cnf = [
        [lit_to_int(lit, atom_name_to_int) for lit in clause]
        for clause in cnf_clauses_expr
    ]
    if verbose:
        print(f"total number of clauses: {len(cnf)}")

    # Use PySAT to solve the CNF.
    with Solver(name=solver_name) as solver:
        for clause in cnf:
            solver.add_clause(clause)
        if verbose:
            # Best effort only: statistics are reported when the chosen
            # backend provides them; failure to do so must not abort solving.
            try:
                print(f"{solver.nof_vars()} variables, {solver.nof_clauses()} clauses.")
            except Exception:
                import traceback
                traceback.print_exc()
                print("[cannot determine number of variables and clauses]")
        start_time = perf_counter()
        if not solver.solve():
            return None
        if verbose:
            print(f"Solving time: {perf_counter() - start_time:.2f} s.")
        pysat_model = solver.get_model()

    # Map atom names (like "a[0]") to the truth value chosen by the solver.
    assignment = {
        int_to_atom_name[abs(lit)]: lit > 0
        for lit in pysat_model
        if abs(lit) in int_to_atom_name
    }

    # Reconstruct each bit-vector value from its linked Booleans.
    # We assume that each variable is a BitVec.
    model_mapping = {}
    for var in variables:
        var_name = var.decl().name()
        bit_width = var.size()
        # LSB is at index 0; a bit without an assigned Boolean counts as 0.
        val = 0
        for i in range(bit_width):
            if assignment.get(f"{var_name}[{i}]", False):
                val |= 1 << i
        model_mapping[var_name] = BitVecVal(val, bit_width)

    # For simplicity the name -> value mapping is returned rather than a
    # Z3 model object.
    return model_mapping
# ----- Demonstration main() Function -----
def main():
    """
    Demonstrate both solving routes on a small example.

    Two 4-bit bit-vectors a and b are constrained by a + b == 6. The problem
    is solved first with Z3 directly and then through PySAT after
    bit-blasting; each result (or an unsatisfiable notice) is printed.
    """
    # Example problem: two 4-bit bit-vector variables with a + b == 6.
    a = BitVec('a', 4)
    b = BitVec('b', 4)
    variables = [a, b]
    constraints = [a + b == 6]

    print("Solving with Z3 directly...")
    direct_model = solve_with_z3(constraints, variables)
    if direct_model is None:
        print("Constraints unsatisfiable (Z3).")
    else:
        print("Z3 model found:")
        print(" a =", direct_model[a])
        print(" b =", direct_model[b])

    print("\nSolving with PySAT (after bit-blasting)...")
    # pysat_clasp is not defined in this file; it is imported here, right
    # before use, and its patch function is called before solving.
    from pysat_clasp import patch_add_clasp
    patch_add_clasp()
    recovered = solve_with_pysat(constraints, variables, verbose=True)
    if recovered is None:
        print("Constraints unsatisfiable (PySAT).")
    else:
        print("Reconstructed model via PySAT:")
        print(" a =", recovered['a'])
        print(" b =", recovered['b'])


# Run the demonstration only when executed as a script.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment