Last active
August 29, 2015 14:10
-
-
Save bonzini/7942680bf9818938a259 to your computer and use it in GitHub Desktop.
BDD library in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2014 Paolo Bonzini
# API inspired by py-simple-bdd.
# License: X11 (MIT)
class Variable(object):
    """A named boolean variable that can label BDD nodes.

    Variables compare equal when their names compare equal.  Each variable
    owns a single-variable node ``var ? T : F``, exposed as ``node``.
    """

    def __init__(self, name):
        self._name = name
        # Must be set before the node is built: Node() hashes its variable.
        self._hash = hash(name)
        self._node = Node(self, Node.T, Node.F)

    def __str__(self):
        return '%s' % (self._name,)

    def __repr__(self):
        return '%s.Variable(%r)' % (__name__, self._name)

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, Variable):
            return False
        # Cheap hash comparison first, then the authoritative name check.
        return self._hash == other._hash and self._name == other._name

    def __ne__(self, other):
        equal = self.__eq__(other)
        return not equal

    @property
    def name(self):
        """The name this variable was created with."""
        return self._name

    @property
    def node(self):
        """The node ``self ? T : F`` representing this variable alone."""
        return self._node
class Ordering(object):
    """A total order on variables, from the top of a BDD to the bottom.

    Each variable gets a level (0 for the topmost) in the order it is added.
    Looking up ``None`` -- the ``var`` of a terminal node -- yields the number
    of variables, i.e. the level just below the last variable.
    """

    def __init__(self, vars=()):
        # Immutable default: a shared list default would be a latent hazard.
        self._vars = []
        self._order = dict()
        self._n = 0
        self.extend(vars)

    def extend(self, vars):
        """Append variables to the bottom of the order.

        Variables already present are skipped.  Re-numbering them would leave
        a gap in the levels and make len() overcount the variables, and
        len() (via ``self[None]``) is used as the variable count.
        """
        for var in vars:
            if var in self._order:
                continue
            self._order[var] = self._n
            # Previously never populated, so the `vars` property was always [].
            self._vars.append(var)
            self._n = self._n + 1

    def __len__(self):
        return self._n

    def __getitem__(self, var):
        """Return the level of var; None (a terminal's var) maps to len(self)."""
        if var is None:
            return self._n
        return self._order[var]

    def sort(self, vars):
        """Return vars as a new list sorted from the top level down."""
        return sorted(vars, key=lambda x: self._order[x])

    @property
    def vars(self):
        """The variables of this ordering, topmost first."""
        return self._vars

    @property
    def comparator(self):
        """Return a two-argument function that takes two variables and
        returns True if the first is above the second in the BDD.

        None (a terminal) is below every variable and above nothing."""
        return (lambda x, y:
                x is not None and
                (y is None or self._order[x] < self._order[y]))
class _AbstractNode(object):
    """Behaviour shared by decision nodes and the two terminal nodes.

    Both traversals walk the DAG below ``self`` and memoize per node object
    (keyed by ``id``), so each distinct node object is composed at most once.
    """

    def visitLazy(self, compose, t=None, f=None):
        """Fold the DAG bottom-up, letting compose decide whether to descend.

        compose(node, get_t, get_f) receives two zero-argument callables that
        return the folded value of node.t and node.f; children are only
        visited if those callables are invoked.  When t/f are not None they
        are used directly as the values of Node.T/Node.F; otherwise the
        terminals are passed to compose like any other node.
        """
        seen = {}
        if t is not None:
            seen[id(Node.T)] = t
        if f is not None:
            seen[id(Node.F)] = f

        def fold(node):
            key = id(node)
            if key in seen:
                return seen[key]
            value = compose(node,
                            lambda: fold(node.t),
                            lambda: fold(node.f))
            seen[key] = value
            return value

        return fold(self)

    def visit(self, compose, t, f):
        """Fold the DAG bottom-up eagerly.

        compose(node, value_t, value_f) receives the already-folded values of
        node.t and node.f (true child first); t and f are the values of
        Node.T and Node.F.
        """
        seen = {id(Node.T): t, id(Node.F): f}

        def fold(node):
            key = id(node)
            if key in seen:
                return seen[key]
            value = compose(node, fold(node.t), fold(node.f))
            seen[key] = value
            return value

        return fold(self)

    def countNodes(self):
        """Returns the number of distinct node objects in the BDD"""
        # One-element list as a mutable counter shared with the closure.
        visited = [0]

        def count(node, get_t, get_f):
            visited[0] += 1
            if not Node.isTerminal(node):
                get_t()
                get_f()
            return visited[0]

        return self.visitLazy(count)

    @property
    def root(self):
        """A bare node is its own root (lets nodes stand in for BDDs)."""
        return self
class Node(_AbstractNode):
    """A decision node ``var ? t : f``.

    Every node is constructed together with its complement
    ``Node(var, ~t, ~f)``, and the two are linked so that ``~node`` is a
    field lookup.  The terminals ``Node.T`` and ``Node.F`` are singleton
    instances of private classes that do not derive from Node.
    """

    @staticmethod
    def isTerminal(p):
        """Tests if p is Node.T or Node.F"""
        # Terminals derive from _AbstractNode but not from Node.
        return not isinstance(p, Node)

    class __TerminalNode(_AbstractNode):
        # A terminal has no variable and is its own child on both branches.
        def __repr__(self):
            return __name__ + '.Node.' + str(self)

        @property
        def var(self):
            return None

        @property
        def t(self):
            return self

        @property
        def f(self):
            return self

    class __TrueNode(__TerminalNode):
        def __str__(self):
            return 'T'

        def __hash__(self):
            return 1

        def __invert__(self):
            return Node.F

    class __FalseNode(__TerminalNode):
        def __str__(self):
            return 'F'

        def __hash__(self):
            return 0

        def __invert__(self):
            return Node.T

    T = __TrueNode()
    F = __FalseNode()

    @staticmethod
    def terminal(value):
        """Returns the terminal corresponding to the boolean interpretation of value"""
        # Terminal objects define no __bool__, so Node.F is truthy and must
        # be excluded explicitly.
        if value and value != Node.F:
            return Node.T
        return Node.F

    def __init__(self, var, t, f, negated=None):
        """Create ``var ? t : f``; `negated` is only passed internally when
        building the complement of an existing node."""
        self._hash = (hash(var) + hash(t) - hash(f)) % 0xFFFFFFFF
        self._var = var
        self._t = t
        self._f = f
        # Explicit None test instead of `negated or ...`, which depended on
        # Node objects happening to be truthy.  Passing `self` to the
        # complement stops the mutual construction after one level.
        if negated is None:
            negated = Node(var, ~t, ~f, self)
        self._negated = negated

    @property
    def var(self):
        return self._var

    @property
    def t(self):
        """Child taken when var is true."""
        return self._t

    @property
    def f(self):
        """Child taken when var is false."""
        return self._f

    def __str__(self):
        # Print common shapes as literals / AND / OR, else as a ternary.
        if self.t == Node.T and self.f == Node.F:
            return str(self.var)
        elif self.t == Node.F and self.f == Node.T:
            return str(self.var) + "'"
        elif self.t == Node.T:
            return '(%s | %s)' % (str(self.var), str(self.f))
        elif self.f == Node.F:
            return '(%s & %s)' % (str(self.var), str(self.t))
        else:
            return '(%s ? %s : %s)' % (str(self.var), str(self.t), str(self.f))

    def __repr__(self):
        return '%s.Node(%s, %s, %s)' % \
            (__name__, repr(self.var), repr(self.t), repr(self.f))

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        # Structural equality: same variable and recursively equal children.
        return (self is other or
                (isinstance(other, Node) and
                 self._hash == other._hash and
                 self.var == other.var and
                 self.t == other.t and
                 self.f == other.f))

    def __ne__(self, other):
        return not self.__eq__(other)

    def __invert__(self):
        return self._negated
class BDD(object):
    """A boolean function represented by a root node.

    A BDD pairs a root node with the Ordering its nodes follow and a node
    cache that is shared by every BDD derived from it.  Each BDD is built
    together with its complement, so ``~bdd`` is a field lookup.  The
    operators ``&``, ``|`` and ``^`` accept a BDD or a bare node on
    either side.
    """

    class __Cache(dict):
        """Unique table (canonical node -> itself) plus the recursive
        algorithms that build new nodes through it."""

        def __init__(self, ordering):
            self[Node.T] = Node.T
            self[Node.F] = Node.F
            # above(a, b) is True when variable a lies above variable b.
            self._above = ordering.comparator

        def replace(self, t, f, p, q=None):
            """Return a canonical node labelled p.var with children t and f.

            p (or q, if given) is reused when it already has exactly these
            children, and no node is built when t == f.  If a structurally
            equal node -- or its complement -- is already cached, the cached
            one (or the complement of it) is returned instead.
            """
            if t == f:
                result = t
            elif t is p.t and f is p.f:
                result = p
            elif q is not None and t is q.t and f is q.f:
                result = q
            else:
                result = bddNode(p.var, t, f)
            if result in self:
                return self[result]
            if ~result in self:
                return ~self[~result]
            self[result] = result
            return result

        def evaluate(self, p, vars, i, assignments, memo):
            """Restrict p by assignments[v] for each variable v in vars[i:].

            vars must be sorted by the ordering; memo maps (node, index) to
            results already computed in this call tree.
            """
            # Assigned variables above p's variable cannot occur in p, but
            # later (lower) ones still may: skip past them rather than
            # returning p unchanged.  For a terminal p every variable is
            # above, so this consumes all of vars.
            while i < len(vars) and self._above(vars[i], p.var):
                i += 1
            if i >= len(vars):
                return p
            key = (p, i)
            if key in memo:
                return memo[key]
            if p.var == vars[i]:
                if Node.terminal(assignments[p.var]) == Node.T:
                    branch = p.t
                else:
                    branch = p.f
                result = self.evaluate(branch, vars, i + 1, assignments, memo)
            else:
                t = self.evaluate(p.t, vars, i, assignments, memo)
                f = self.evaluate(p.f, vars, i, assignments, memo)
                result = self.replace(t, f, p)
            memo[key] = result
            return result

        def reduce(self, p, memo=None):
            """Rebuild the DAG below p bottom-up through replace().

            memo (keyed by node id) ensures a shared sub-DAG is reduced only
            once; without it the work is exponential in the sharing depth.
            """
            if memo is None:
                memo = {}
            key = id(p)
            if key in memo:
                return memo[key]
            t = p.t if Node.isTerminal(p.t) else self.reduce(p.t, memo)
            f = p.f if Node.isTerminal(p.f) else self.reduce(p.f, memo)
            result = self.replace(t, f, p)
            memo[key] = result
            return result

        def apply(self, p, q, binop, memo):
            """Combine p and q with binop using the standard Shannon recursion.

            binop(p, q) returns a node for cases it can decide directly and
            None otherwise; memo maps (p, q) pairs to computed results.
            """
            easy = binop(p, q)
            if easy is not None:
                return easy
            key = (p, q)
            if key in memo:
                return memo[key]
            if p.var == q.var:
                t = self.apply(p.t, q.t, binop, memo)
                f = self.apply(p.f, q.f, binop, memo)
                result = self.replace(t, f, p, q)
            elif self._above(p.var, q.var):
                # q does not depend on p.var: split on p only.
                t = self.apply(p.t, q, binop, memo)
                f = self.apply(p.f, q, binop, memo)
                result = self.replace(t, f, p)
            else:
                t = self.apply(p, q.t, binop, memo)
                f = self.apply(p, q.f, binop, memo)
                result = self.replace(t, f, q)
            memo[key] = result
            return result

    def __init__(self, root, ordering, cache=None, negated=None):
        self._root = root
        self._ordering = ordering
        # Explicit None tests: the cache is a dict, so a truthiness test
        # would silently replace an empty one.
        if cache is None:
            cache = BDD.__Cache(ordering)
        self._cache = cache
        # The complement shares ordering and cache; passing `self` as its
        # complement stops the mutual construction.
        if negated is None:
            negated = BDD(~root, ordering, cache, self)
        self._negated = negated

    @staticmethod
    def conjunction(vars, ordering, cache=None):
        """Form a bdd that is the AND of all the variables."""
        result = Node.T
        for i in reversed(ordering.sort(vars)):
            result = Node(i, result, Node.F)
        return BDD(result, ordering, cache)

    @staticmethod
    def disjunction(vars, ordering, cache=None):
        """Form a bdd that is the OR of all the variables."""
        result = Node.F
        for i in reversed(ordering.sort(vars)):
            result = Node(i, Node.T, result)
        return BDD(result, ordering, cache)

    @staticmethod
    def _foldBalanced(bdds, unit, combine):
        """Combine unit and every BDD in bdds with the associative combine.

        Operands are merged like the carries of a binary counter, so each
        combine() joins results built from similar numbers of inputs rather
        than always folding into a single accumulator.
        """
        n = 1
        stack = [unit]
        for item in bdds:
            n += 1
            # Each trailing zero bit of n is one carry: merge with stack top.
            carry = n & -n
            while carry != 1:
                item = combine(stack.pop(), item)
                carry >>= 1
            stack.append(item)
        result = stack.pop()
        while stack:
            result = combine(result, stack.pop())
        return result

    @staticmethod
    def andAll(bdds, ordering, cache=None):
        """Return the AND of all BDDs in bdds (True when bdds is empty)."""
        return BDD._foldBalanced(bdds, BDD(Node.T, ordering, cache),
                                 lambda a, b: a & b)

    @staticmethod
    def orAll(bdds, ordering, cache=None):
        """Return the OR of all BDDs in bdds (False when bdds is empty)."""
        return BDD._foldBalanced(bdds, BDD(Node.F, ordering, cache),
                                 lambda a, b: a | b)

    def countTrue(self):
        """Return the number of satisfying assignments over all variables
        of the ordering."""
        order = self._ordering

        def count(p, vt, vf):
            # Each variable skipped between p and a child doubles that
            # child's count.
            return ((vt << (order[p.t.var] - order[p.var] - 1)) +
                    (vf << (order[p.f.var] - order[p.var] - 1)))

        # Variables above the root are free as well.
        return self.visit(count, 1, 0) << order[self._root.var]

    def evaluate(self, assignments):
        """Return the BDD restricted by assignments (variable -> truth value)."""
        vars = self._ordering.sort(assignments)
        node = self._cache.evaluate(self._root, vars, 0, assignments, dict())
        return BDD(node, self._ordering, self._cache)

    def __xnor(self, a, b):
        # a XOR (NOT b) == (a <-> b)
        return BDD(a.root, self._ordering, self._cache) ^ ~b

    def force(self, assignments):
        """Return self AND (v <-> assignments[v]) for every assigned v.

        Values are combined with ``^``/``~``, so they are expected to be
        BDDs (or nodes), not plain booleans.
        """
        vars = self._ordering.sort(assignments)
        iffs = (self.__xnor(v.node, assignments[v]) for v in vars)
        return self & BDD.andAll(iffs, self._ordering, self._cache)

    def apply(self, binop, other):
        """Combine with other (a BDD or node) using a binop shortcut function."""
        node = self._cache.apply(self._root, other.root, binop, dict())
        return BDD(node, self._ordering, self._cache)

    def reduce(self):
        """Return an equivalent BDD whose nodes are canonicalized via the cache."""
        node = self._cache.reduce(self._root)
        return BDD(node, self._ordering, self._cache)

    def visitLazy(self, compose, t=None, f=None):
        # Defaults match _AbstractNode.visitLazy, which this delegates to.
        return self._root.visitLazy(compose, t, f)

    def visit(self, compose, t, f):
        return self._root.visit(compose, t, f)

    @property
    def root(self):
        return self._root

    @property
    def ordering(self):
        return self._ordering

    def __str__(self):
        return str(self._root)

    def __hash__(self):
        return hash(self._root)

    def __eq__(self, other):
        return (self is other or
                (isinstance(other, BDD) and
                 self._root == other._root and
                 self._ordering is other._ordering))

    def __ne__(self, other):
        return not self.__eq__(other)

    def __invert__(self):
        return self._negated

    @staticmethod
    def __AND(p, q):
        # Terminal shortcuts for AND; None means "recurse".
        if p == Node.T:
            return q
        if q == Node.T:
            return p
        if p == Node.F:
            return p
        if q == Node.F:
            return q
        return None

    def __and__(self, other):
        return self.apply(BDD.__AND, other)

    def __rand__(self, other):
        return self & other

    @staticmethod
    def __OR(p, q):
        # Terminal shortcuts for OR; None means "recurse".
        if p == Node.F:
            return q
        if q == Node.F:
            return p
        if p == Node.T:
            return p
        if q == Node.T:
            return q
        return None

    def __or__(self, other):
        return self.apply(BDD.__OR, other)

    def __ror__(self, other):
        return self | other

    @staticmethod
    def __XOR(p, q):
        # Terminal shortcuts for XOR; None means "recurse".
        if p == Node.T:
            return ~q
        if q == Node.T:
            return ~p
        if p == Node.F:
            return q
        if q == Node.F:
            return p
        return None

    def __xor__(self, other):
        return self.apply(BDD.__XOR, other)

    def __rxor__(self, other):
        return self ^ other
def bddNode(var, t, f): | |
if t.var == var: | |
t = t.t | |
if f.var == var: | |
f = f.f | |
if t == f: | |
return t | |
return Node(var, t, f) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2011 Craig Eales
# Copyright 2014 Paolo Bonzini
# Based on the py-simple-bdd unit tests
# This file is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this file.  If not, see <http://www.gnu.org/licenses/>.
from bdd import Variable, Node, BDD, bddNode, Ordering | |
import bdd # for eval | |
import unittest | |
# Five test variables, created in order; the ordering puts x at the top
# and u at the bottom of every BDD.
x, y, z, w, u = map(Variable, 'xyzwu')
ordering = Ordering([x, y, z, w, u])
T, F = Node.T, Node.F


def as_bdd(n):
    """Wrap node n in a BDD over the shared test ordering (fresh cache)."""
    return BDD(n, ordering)
class TestNode(unittest.TestCase):
    """Unit tests for Node, BDD and Ordering over the module-level x..u."""

    def testTrue(self):
        self.assertEqual(T, T)
        self.assertNotEqual(T, F)
        self.assertEqual(T, eval(repr(T)))
        self.assertEqual(as_bdd(T).evaluate({x: False}).root, T)
        self.assertEqual(as_bdd(T).evaluate({x: True}).root, T)
        self.assertTrue(Node.isTerminal(T))
        self.assertEqual(Node.terminal(True), T)
        self.assertEqual(Node.terminal([x]), T)
        self.assertEqual(~T, F)

    def testFalse(self):
        self.assertEqual(F, F)
        self.assertNotEqual(F, T)
        self.assertEqual(F, eval(repr(F)))
        self.assertEqual(as_bdd(F).evaluate({x: False}).root, F)
        self.assertEqual(as_bdd(F).evaluate({x: True}).root, F)
        self.assertTrue(Node.isTerminal(F))
        self.assertEqual(Node.terminal(False), F)
        self.assertEqual(Node.terminal([]), F)
        self.assertEqual(~F, T)

    def testSingleNode(self):
        n1 = bddNode(x, T, F)
        self.assertFalse(Node.isTerminal(n1))
        self.assertEqual(n1.countNodes(), 3)
        self.assertEqual(n1, n1)
        self.assertEqual(n1, eval(repr(n1)))
        self.assertEqual(hash(n1), hash(eval(repr(n1))))
        self.assertEqual(as_bdd(n1).evaluate({x: True}).root, T)
        self.assertEqual(as_bdd(n1).evaluate({x: False}).root, F)
        self.assertEqual(as_bdd(n1).evaluate({y: True}).root, n1)
        n2 = bddNode(y, T, F)
        self.assertNotEqual(n1, n2)
        self.assertEqual(n2.countNodes(), 3)
        n3 = bddNode(x, T, T)
        self.assertNotEqual(n1, n3)
        self.assertEqual(n3, T)
        self.assertEqual(n3.countNodes(), 1)
        n4 = bddNode(x, F, T)
        self.assertNotEqual(n1, n4)
        self.assertEqual(n4.countNodes(), 3)
        self.assertEqual(~n1, n4)
        self.assertEqual(~n4, n1)
        n5 = bddNode(x, F, F)
        self.assertNotEqual(n1, n5)
        self.assertEqual(n5, F)
        self.assertEqual(n5.countNodes(), 1)

    def testNestedNode(self):
        n1 = bddNode(z, T, F)
        n2 = bddNode(z, F, T)
        n3 = bddNode(z, T, F)
        cn1 = bddNode(y, n1, n2)
        cn2 = bddNode(y, n2, n1)
        cn3 = bddNode(y, n3, n2)
        self.assertEqual(cn1, cn1)
        self.assertNotEqual(cn1, cn2)
        self.assertEqual(cn1, cn3)
        self.assertEqual(cn1.countNodes(), 5)
        self.assertEqual(hash(cn1), hash(cn3))
        self.assertEqual(cn1, eval(repr(cn1)))
        bdd1 = as_bdd(cn1)
        self.assertEqual(bdd1.evaluate({y: True}).root, n1)
        self.assertEqual(bdd1.evaluate({y: False}).root, n2)
        self.assertEqual(bdd1.evaluate({z: True}).evaluate({y: True}).root, T)
        self.assertEqual(bdd1.evaluate({z: True}).evaluate({y: False}).root, F)
        self.assertEqual(bdd1.evaluate({z: False}).evaluate({y: True}).root, F)
        self.assertEqual(bdd1.evaluate({z: False}).evaluate({y: False}).root, T)
        self.assertEqual(bdd1.evaluate({y: True}).evaluate({z: True}).root, T)
        self.assertEqual(bdd1.evaluate({y: False}).evaluate({z: True}).root, F)
        self.assertEqual(bdd1.evaluate({y: True}).evaluate({z: False}).root, F)
        self.assertEqual(bdd1.evaluate({y: False}).evaluate({z: False}).root, T)
        self.assertEqual(bdd1.evaluate({z: True, y: True}).root, T)
        self.assertEqual(bdd1.evaluate({z: True, y: False}).root, F)
        self.assertEqual(bdd1.evaluate({z: False, y: True}).root, F)
        self.assertEqual(bdd1.evaluate({z: False, y: False}).root, T)
        cn4 = bddNode(z, bddNode(z, T, F), bddNode(z, T, F))
        bdd4 = as_bdd(cn4)
        self.assertEqual(bdd4.evaluate({z: True}).root, T)
        self.assertEqual(bdd4.evaluate({z: False}).root, F)
        self.assertEqual(~bdd4.evaluate({z: True}).root, F)
        self.assertEqual(~bdd4.evaluate({z: False}).root, T)
        self.assertEqual(cn4.countNodes(), 3)
        self.assertEqual(~n1, n2)
        self.assertEqual(~n2, n1)
        self.assertEqual(~~cn1, cn1)
        self.assertNotEqual(~cn1, cn1)

    def testOrderings(self):
        above = ordering.comparator
        self.assertTrue(above(y, w))
        self.assertFalse(above(z, y))
        # assertEqual reports both lists on failure, unlike assertTrue(a == b).
        self.assertEqual(ordering.sort([y, w]), [y, w])
        self.assertEqual(ordering.sort([w, y]), [y, w])
        self.assertEqual(ordering.sort([x, w, y]), [x, y, w])

    def testSimple(self):
        self.assertEqual(x.node, bddNode(x, T, F))
        self.assertEqual(~(x.node), bddNode(x, F, T))

    def testConjunction(self):
        self.assertEqual(BDD.conjunction([], ordering).root, T)
        self.assertEqual(BDD.conjunction([x], ordering).root, x.node)
        self.assertEqual(BDD.conjunction([x, y], ordering).root,
                         bddNode(x, y.node, F))
        self.assertEqual(BDD.conjunction([x, y, z], ordering).root,
                         bddNode(x, bddNode(y, z.node, F), F))

    def testDisjunction(self):
        self.assertEqual(BDD.disjunction([], ordering).root, F)
        self.assertEqual(BDD.disjunction([x], ordering).root, x.node)
        self.assertEqual(BDD.disjunction([x, y], ordering).root,
                         bddNode(x, T, y.node))
        self.assertEqual(BDD.disjunction([x, y, z], ordering).root,
                         bddNode(x, T, bddNode(y, T, z.node)))

    def testOr(self):
        bddx = as_bdd(x.node)
        bddy = as_bdd(y.node)
        n1 = BDD.disjunction([x, y], ordering)
        n2 = bddx | y.node
        self.assertEqual(n1, n2)
        n3 = bddx | bddy
        self.assertEqual(n1, n3)
        n4 = x.node | bddy
        self.assertEqual(n1, n4)

    def testAnd(self):
        bddx = as_bdd(x.node)
        bddy = as_bdd(y.node)
        n1 = BDD.conjunction([x, y], ordering)
        n2 = bddx & y.node
        self.assertEqual(n1, n2)
        n3 = bddx & bddy
        self.assertEqual(n1, n3)
        n4 = x.node & bddy
        self.assertEqual(n1, n4)

    def testComplexApply(self):
        bddx = as_bdd(x.node)
        bddy = as_bdd(y.node)
        xnor = (bddx | ~bddy) & (bddy | ~bddx)
        xor = (bddx & ~bddy) | (bddy & ~bddx)
        self.assertEqual(xnor.evaluate({x: True, y: True}).root, T)
        self.assertEqual(xnor.evaluate({x: True, y: False}).root, F)
        self.assertEqual(xnor.evaluate({x: False, y: True}).root, F)
        self.assertEqual(xnor.evaluate({x: False, y: False}).root, T)
        self.assertEqual(xor.evaluate({x: True, y: True}).root, F)
        self.assertEqual(xor.evaluate({x: True, y: False}).root, T)
        self.assertEqual(xor.evaluate({x: False, y: True}).root, T)
        self.assertEqual(xor.evaluate({x: False, y: False}).root, F)
        self.assertEqual(xor, ~xnor)
        self.assertEqual((xor | xnor).root, T)
        self.assertEqual((xor & xnor).root, F)

    def testCountTrue(self):
        self.assertEqual(as_bdd(Node.T).countTrue(), 1 << len(ordering))
        self.assertEqual(as_bdd(Node.F).countTrue(), 0)
        bdd1 = BDD.conjunction([x, y, z, w, u], ordering)
        self.assertEqual(bdd1.countTrue(), 1)
        bdd2 = BDD.disjunction([x, y, z, w, u], ordering)
        self.assertEqual(bdd2.countTrue(), (1 << len(ordering)) - 1)

    def testAndAll(self):
        bdd1 = BDD.conjunction([x, y, z, w], ordering)
        bdd2 = BDD.andAll((as_bdd(v.node) for v in [x, y, z, w]), ordering)
        self.assertEqual(bdd1, bdd2)

    def testOrAll(self):
        bdd1 = BDD.disjunction([x, y, z, w], ordering)
        bdd2 = BDD.orAll((as_bdd(v.node) for v in [x, y, z, w]), ordering)
        self.assertEqual(bdd1, bdd2)

    def testForce(self):
        bdd1 = (as_bdd(x.node) & as_bdd(w.node)) | as_bdd(u.node)
        bdd2 = (as_bdd(y.node) | as_bdd(z.node))
        bdd3 = (as_bdd(y.node) & as_bdd(z.node))
        # Named `forced`, not `bdd`, so the imported bdd module (used by
        # eval(repr(...)) above) is not shadowed.
        forced = bdd1.force({w: bdd2, u: bdd3})
        # w and u become determined, so two free variables are lost.
        self.assertEqual(forced.countTrue(),
                         ((as_bdd(x.node) & bdd2) | bdd3).countTrue() >> 2)

    def testReduce(self):
        z1 = Node(z, Node.T, Node.F)
        z2 = Node(z, Node.T, Node.F)
        # Two distinct but equal children: a redundant test of x.
        redundant = Node(x, z1, z2)
        self.assertEqual(redundant.countNodes(), 5)
        reduced = as_bdd(redundant).reduce()
        self.assertEqual(reduced.root, z.node)
        self.assertEqual(reduced.root.countNodes(), 3)


if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment