Last active
March 26, 2022 06:09
-
-
Save mb64/e178dd9893ae13d4f22241963770f6b2 to your computer and use it in GitHub Desktop.
Itty bitty SMT solver: DPLL(T) where T = equality. Likely buggy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Itty bitty SMT solver: DPLL(T) where T = equality *)
(*
# let x, y = 0, 1 (* integer variable IDs *) ;;
# let prob = SMT.new_problem 2 (* 2 for two variables *) ;;
# let b = SMT.new_bool prob;;
# SMT.add_clause prob [b; SMT.eq prob x y];
  SMT.add_clause prob [SMT.not b];
  SMT.solve prob;;
- : SMT.response = SMT.SAT
# SMT.add_clause prob [SMT.not (SMT.eq prob x y)];
  SMT.solve prob;;
- : SMT.response = SMT.UNSAT
*)
(* A small DPLL SAT solver with two-watched-literal unit propagation.
   Problems are built persistently with [add_var] / [add_clause] and solved
   with [dpll], which hands every complete propositional model to a
   caller-supplied callback (used by SMT below as the theory check). *)
module SAT : sig
  (* A literal: a boolean variable or its negation. *)
  type atom

  (* Negate a literal; [not (not a)] is [a]. *)
  val not : atom -> atom

  (* A disjunction of literals. *)
  type clause = atom list

  type problem

  (* The empty problem: no variables, no clauses. *)
  val no_problem : problem

  (* Allocate a fresh variable; returns its positive literal and the
     extended problem. *)
  val add_var : problem -> atom * problem

  (* Add a clause (disjunction) to the problem. *)
  val add_clause : problem -> atom list -> problem

  (* Truth value of a literal in the current (possibly partial) model. *)
  type soln = Yes | No | IDK

  (* dpll and its callback both raise Unsat if it's unsat.
     TODO: have the callback provide an unsat core to learn *)
  exception Unsat

  (* [dpll p verify] searches for a total assignment satisfying every clause
     of [p] and calls [verify model] on it. If [verify] raises [Unsat] the
     search resumes with the next candidate; if [verify] returns, [dpll]
     returns. Raises [Unsat] once all candidates are exhausted. *)
  val dpll : problem -> ((atom -> soln) -> unit) -> unit
end = struct
  (* Variable [v] is the non-negative int [v]; its negation is [lnot v]
     (= -v-1), so literals range over [-num_vars, num_vars - 1]. *)
  type atom = int
  let not = lnot
  let is_neg x = x < 0
  let atom_to_var x = if is_neg x then not x else x

  type clause = atom list
  type problem = { num_vars: int; clauses: clause list }

  let no_problem: problem = { num_vars = 0; clauses = [] }
  let add_var (p : problem) = p.num_vars, { p with num_vars = p.num_vars + 1 }
  let add_clause (p : problem) c = { p with clauses = c :: p.clauses }

  type soln = Yes | No | IDK
  exception Unsat

  let flip_soln = function
    | Yes -> No
    | No -> Yes
    | IDK -> IDK

  let dpll ({ num_vars; clauses } : problem) (verify : (atom -> soln) -> unit) =
    let model: soln array = Array.make num_vars IDK in
    (* Drop duplicate literals so the two watchers of a clause are always
       distinct literals. *)
    let clauses: clause array =
      Array.of_list (List.map (List.sort_uniq compare) clauses) in
    let num_clauses = Array.length clauses in
    (* watch_clauses.(atom_to_idx l) lists the clauses currently watching l;
       watch_literal_{1,2}.(i) are the two literals watched by clause i. *)
    let watch_clauses: int list array = Array.make (2 * num_vars) [] in
    let watch_literal_1: int array = Array.make num_clauses 0 in
    let watch_literal_2: int array = Array.make num_clauses 0 in
    (* Literals that are forced to be true but not yet assigned. *)
    let unit_prop_worklist = ref [] in
    let gotta_unit_prop x = unit_prop_worklist := x :: !unit_prop_worklist in
    (* Shift literals from [-num_vars, num_vars) into [0, 2*num_vars). *)
    let atom_to_idx a = a + num_vars in

    (* initialize watch literals *)
    Array.iteri (fun i cl -> match cl with
      | [] -> raise Unsat
      | [x] -> gotta_unit_prop x
      | x :: y :: _ ->
        let watch a =
          let idx = atom_to_idx a in
          watch_clauses.(idx) <- i :: watch_clauses.(idx) in
        watch x; watch y;
        watch_literal_1.(i) <- x;
        watch_literal_2.(i) <- y) clauses;

    (* Assigned literals, most recent first. *)
    let trail = ref [] in
    let current_state a =
      if is_neg a then flip_soln model.(not a) else model.(a) in

    (* Undo every assignment made since (and including) literal [a]. The
       decision itself must be undone too, otherwise its negation can never
       be tried afterwards. *)
    let rec backtrack_through a = match !trail with
      | [] -> failwith "impossible -- needs to reach a"
      | a' :: rest ->
        trail := rest;
        model.(atom_to_var a') <- IDK;
        if a' <> a then backtrack_through a in

    (* Invariant: whenever [set_to_true] / [unit_prop_all] return normally,
       the worklist is empty. *)
    let rec unit_prop_all () = match !unit_prop_worklist with
      | [] -> ()
      | a :: rest -> unit_prop_worklist := rest; set_to_true a
    and set_to_true a = match current_state a with
      | Yes -> unit_prop_all () (* already true: keep draining the worklist *)
      | No -> raise Unsat
      | IDK ->
        trail := a :: !trail;
        (if is_neg a then model.(not a) <- No else model.(a) <- Yes);
        (* a has just been set to true. Look at (not a)-watching clauses for
           unit prop opportunities. *)
        let one_clause i =
          let clause = clauses.(i) in
          if List.exists (fun l -> current_state l = Yes) clause then ()
          else
            match List.filter (fun l -> current_state l = IDK) clause with
            | [] -> raise Unsat (* every literal is false: conflict *)
            | [x] -> gotta_unit_prop x
            | (x :: _) as unassigned ->
              (* still at least two things left. make one the new watcher *)
              let old_watcher = not a in
              let this_watch, other_watch =
                if watch_literal_1.(i) = old_watcher
                then watch_literal_1, watch_literal_2
                else watch_literal_2, watch_literal_1 in
              assert (this_watch.(i) = old_watcher);
              let other_watcher = other_watch.(i) in
              (* The replacement must differ from the watcher we keep. *)
              begin match
                List.find_opt (fun l -> l <> other_watcher) unassigned
              with
              | None -> gotta_unit_prop x
              | Some new_watcher ->
                let old_idx = atom_to_idx old_watcher in
                let new_idx = atom_to_idx new_watcher in
                this_watch.(i) <- new_watcher;
                watch_clauses.(new_idx) <- i :: watch_clauses.(new_idx);
                watch_clauses.(old_idx) <-
                  List.filter (fun j -> i <> j) watch_clauses.(old_idx)
              end in
        (* Snapshot the list: one_clause edits this very watch list. *)
        let watching = watch_clauses.(atom_to_idx (not a)) in
        List.iter one_clause watching;
        unit_prop_all () in

    (* The main recursive DPLL loop! *)
    (* This is already a lot of code but still it'd be nice to do CDCL :/ *)
    let rec go v =
      (* dumbest possible variable ordering: 0 to N-1, in order *)
      if v = num_vars then verify current_state else
      if model.(v) <> IDK then go (v + 1) else
      try
        set_to_true v;
        go (v + 1)
      with Unsat -> begin
        backtrack_through v;
        (* Pending implications belonged to the branch we just abandoned. *)
        unit_prop_worklist := [];
        set_to_true (not v);
        go (v + 1)
      end in
    unit_prop_all ();
    go 0
end
(* SMT front end: propositional structure is delegated to SAT, and the only
   theory is equality between integer-identified variables, checked with a
   union-find over each complete propositional model. *)
module SMT : sig
  type atom
  val not : atom -> atom
  type clause = atom list
  type var_id = int
  type problem
  val new_problem : int (* number of variables *) -> problem
  val eq : problem -> var_id -> var_id -> atom
  val new_bool : problem -> atom
  val add_clause : problem -> clause -> unit
  type response = SAT | UNSAT
  val solve : problem -> response
end = struct
  type atom = SAT.atom
  type clause = SAT.clause
  type var_id = int
  type response = SAT | UNSAT

  (* [atoms] maps an ordered pair of variable ids to the SAT atom standing
     for "these two variables are equal"; [sat] is the propositional
     problem accumulated so far. *)
  type problem =
    { num_vars: int
    ; atoms: (var_id * var_id, atom) Hashtbl.t
    ; mutable sat: SAT.problem }

  let not = SAT.not

  let new_problem num_vars: problem =
    { num_vars; atoms = Hashtbl.create 16; sat = SAT.no_problem }

  (* Allocate a fresh propositional variable in the underlying SAT problem. *)
  let fresh_atom (p: problem) =
    let a, extended = SAT.add_var p.sat in
    p.sat <- extended;
    a

  (* The atom for [x = y]. The pair is ordered so that [eq p x y] and
     [eq p y x] share one atom; it is created on first use. *)
  let eq (p: problem) x y =
    let key = (min x y, max x y) in
    try Hashtbl.find p.atoms key
    with Not_found ->
      let a = fresh_atom p in
      Hashtbl.add p.atoms key a;
      a

  (* A fresh boolean atom with no theory meaning. *)
  let new_bool (p: problem) = fresh_atom p

  let add_clause (p: problem) c =
    p.sat <- SAT.add_clause p.sat c

  (* Run DPLL; each complete model is checked against the equality theory:
     merge the classes of all equalities set to Yes, then reject the model
     (by raising SAT.Unsat) if an equality set to No joins two variables
     that ended up in the same class. *)
  let solve (p: problem) =
    let theory_check assignment =
      let repr = Array.init p.num_vars (fun k -> k) in
      (* Union-find root lookup with path compression. *)
      let rec root k =
        let up = repr.(k) in
        if up <> k then begin
          let r = root up in
          repr.(k) <- r;
          r
        end else k in
      let merge i j =
        let rj = root j in
        let ri = root i in
        repr.(ri) <- rj in
      let assert_equal (i, j) a =
        if assignment a = SAT.Yes then merge i j in
      let check_distinct (i, j) a =
        if assignment a = SAT.No && root i = root j then raise SAT.Unsat in
      Hashtbl.iter assert_equal p.atoms;
      Hashtbl.iter check_distinct p.atoms in
    try
      SAT.dpll p.sat theory_check;
      SAT
    with SAT.Unsat -> UNSAT
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment