Created
November 8, 2024 03:15
-
-
Save kmicinski/720aa6a0dd59835110b12215826d59e5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::{HashMap, HashSet}; | |
type Clause = Vec<i32>; | |
type Clauses = HashSet<Clause>; | |
type Trail = Vec<TrailEntry>; | |
type Assignment = HashMap<i32, bool>; | |
/// One step recorded on the solver's trail.
///
/// The payload is the literal that was made true: a positive value sets the
/// variable to `true`, a negative value sets it to `false`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TrailEntry {
    /// A literal chosen as a branching guess; backtracking may flip it.
    Decision(i32),
    /// A literal that was forced, either by a unit clause or by flipping an
    /// earlier decision that led to a conflict.
    Propagated(i32),
}
/// How a single clause stands under a partial assignment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClauseStatus {
    /// At least one literal of the clause is true.
    Satisfied,
    /// No literal is true and exactly one literal is still unassigned.
    Unit,
    /// Every literal of the clause is false.
    Conflicting,
    /// No literal is true and more than one literal is still unassigned.
    Unassigned,
}
#[derive(Debug, Clone)] | |
enum DpllState { | |
Running(Trail, Assignment, Clauses), | |
Sat(Trail), | |
Unsat, | |
} | |
fn classify_clause(clause: &Clause, assignment: &Assignment) -> ClauseStatus { | |
let (mut pos, mut neg) = (0, 0); | |
for &lit in clause { | |
match assignment.get(&lit.abs()) { | |
Some(&value) => { | |
if value { | |
if lit < 0 { | |
neg += 1; | |
} else { | |
pos += 1; | |
} | |
} else { | |
if lit < 0 { | |
pos += 1; | |
} else { | |
neg += 1; | |
} | |
} | |
} | |
None => {} | |
} | |
} | |
if pos > 0 { | |
ClauseStatus::Satisfied | |
} else if (clause.len() as i32 - neg) == 1 { | |
ClauseStatus::Unit | |
} else if clause.len() as i32 == neg { | |
ClauseStatus::Conflicting | |
} else { | |
ClauseStatus::Unassigned | |
} | |
} | |
fn has_unit_clause(clauses: &Clauses, assignment: &Assignment) -> Option<Clause> { | |
clauses.iter().find(|cl| classify_clause(cl, assignment) == ClauseStatus::Unit).cloned() | |
} | |
fn find_first_unassigned_lit(clause: &Clause, assignment: &Assignment) -> Option<i32> { | |
clause.iter().find(|&&lit| !assignment.contains_key(&lit.abs())).cloned() | |
} | |
fn step(state: DpllState) -> DpllState { | |
match state { | |
DpllState::Running(mut trail, mut assignment, clauses) => { | |
if let Some(unit_clause) = has_unit_clause(&clauses, &assignment) { | |
if let Some(ulit) = find_first_unassigned_lit(&unit_clause, &assignment) { | |
trail.push(TrailEntry::Propagated(ulit)); | |
assignment.insert(ulit.abs(), ulit > 0); | |
return DpllState::Running(trail, assignment, clauses); | |
} | |
} | |
if clauses.iter().any(|cl| classify_clause(cl, &assignment) == ClauseStatus::Unassigned) { | |
if let Some(unassigned_clause) = clauses.iter().find(|cl| classify_clause(cl, &assignment) == ClauseStatus::Unassigned) { | |
if let Some(dlit) = find_first_unassigned_lit(unassigned_clause, &assignment) { | |
trail.push(TrailEntry::Decision(dlit)); | |
assignment.insert(dlit.abs(), dlit > 0); | |
return DpllState::Running(trail, assignment, clauses); | |
} | |
} | |
} | |
if clauses.iter().any(|cl| classify_clause(cl, &assignment) == ClauseStatus::Conflicting) { | |
if trail.iter().all(|entry| !matches!(entry, TrailEntry::Decision(_))) { | |
return DpllState::Unsat; | |
} else { | |
let mut new_trail = trail.clone(); | |
let mut new_assignment = assignment.clone(); | |
while let Some(entry) = new_trail.pop() { | |
match entry { | |
TrailEntry::Decision(lit) => { | |
new_trail.push(TrailEntry::Propagated(-lit)); | |
new_assignment.insert(lit.abs(), !(lit > 0)); | |
return DpllState::Running(new_trail, new_assignment, clauses); | |
} | |
TrailEntry::Propagated(lit) => { | |
new_assignment.remove(&lit.abs()); | |
} | |
} | |
} | |
} | |
} | |
DpllState::Sat(trail) | |
} | |
other => other, | |
} | |
} | |
fn run(clauses: &Clauses) { | |
let mut state = DpllState::Running(Vec::new(), Assignment::new(), clauses.clone()); | |
loop { | |
match &state { | |
DpllState::Unsat => { | |
println!("UNSAT"); | |
break; | |
} | |
DpllState::Sat(trail) => { | |
println!("SAT {:?}", trail); | |
// Generate a complete assignment | |
let max_var = clauses.iter().flat_map(|cl| cl.iter()).map(|&x| x.abs()).max().unwrap_or(0); | |
let mut complete_assignment: Assignment = trail.iter().filter_map(|entry| { | |
match entry { | |
TrailEntry::Decision(lit) | TrailEntry::Propagated(lit) => Some((lit.abs(), *lit > 0)), | |
} | |
}).collect(); | |
// Print all variables up to the max_var | |
for var in 1..=max_var { | |
let value = complete_assignment.get(&var).cloned().unwrap_or(false); | |
println!("Variable {}: {}", var, value); | |
} | |
break; | |
} | |
_ => { | |
println!("{:?}", state); | |
state = step(state); | |
println!("⇒"); | |
} | |
} | |
} | |
} | |
fn main() { | |
let clauses: Clauses = [ | |
vec![1, 2, 3], | |
vec![-1, 2], | |
vec![-2, 3, 4], | |
vec![-3, -4, 5], | |
vec![1, -5], | |
vec![-1, -3, -5], | |
vec![2, 4, -5], | |
vec![-1, -2, -3, -4, -5], | |
] | |
.into_iter() | |
.collect(); | |
run(&clauses); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment