# bitfield.py
# Source: GitHub gist allanlw/86ad5d3f91aac1a82c18b9c9ce08575c
# (created by @allanlw on June 3, 2019).
# Copyright 2019, Akamai Technologies, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import itertools
import math
import sympy
import tabulate
import numpy
import logging
import string
class CombinatoricBitfield(object):
    """Encode a combinatoric object (a tuple of finite-valued fields) as sympy symbols.

    Each field is assigned a group of boolean sympy symbols:

    * binary mode (``onehot=False``): a field with ``n`` values uses
      ``ceil(log2(n))`` bits and value ``j`` is the binary number ``j``;
    * one-hot mode: a field with ``n > 1`` values uses ``n`` bits and value
      ``j`` is the bit pattern ``1 << j``.

    A field with a single value uses no bits in either mode.  Within each
    field the symbols are ordered most-significant bit first, and
    ``self.symbols`` is the concatenation of all fields' symbols.

    Args:
        names: one name per field.
        values: for each field, an iterable of the values it can take.
        value_names: for each field, an iterable of display names parallel
            to ``values``.
        onehot: use one-hot instead of binary encoding.
    """

    def __init__(self, names, values, value_names, onehot = False):
        self.onehot = onehot
        self.field_names = list(names)
        # Materialize as real lists: these are iterated many times, and on
        # Python 3 ``map`` would return a one-shot iterator.
        self.field_values = [list(v) for v in values]
        self.field_value_names = [list(v) for v in value_names]
        self.field_bitwidths = self._bitwidths()
        # One list of symbols per field that uses at least one bit
        # (zero-width fields are skipped, see _symbol_names).
        self.field_symbols = list(self._symbols())
        self.symbols = list(itertools.chain.from_iterable(self.field_symbols))
        assert len(self.symbols) == sum(self.field_bitwidths)

    def _ceillg2(self, x):
        """Return ceil(log2(x)) for a positive integer ``x``.

        Uses exact integer arithmetic; the float form ``log(x) / log(2)``
        can round slightly above an integer for large powers of two and then
        overshoot by one after ``ceil``.

        Raises:
            ValueError: if ``x < 1`` (log is undefined there).
        """
        if x < 1:
            raise ValueError("math domain error")
        return (x - 1).bit_length()

    def _bitwidths(self):
        """Return a list with the number of bits used to encode each field."""
        sizes = [len(v) for v in self.field_values]
        if self.onehot:
            # A single-valued field is implied, so it needs no bits.
            return [(n if n != 1 else 0) for n in sizes]
        return [self._ceillg2(n) for n in sizes]

    def _symbol_names(self):
        """Yield, for each field of non-zero width, its symbol names (MSB first).

        Zero-width fields are skipped, so the output is NOT index-aligned
        with ``field_names``.  The names are only for debugging.
        """
        for i, w in enumerate(self.field_bitwidths):
            if w == 0:
                continue
            if w == 1:
                yield [self.field_names[i]]
            else:
                yield [self.field_names[i] + "_" + str(j) for j in reversed(range(w))]

    def _symbols(self):
        """Yield a list of sympy symbols for each entry of ``_symbol_names``."""
        for names in self._symbol_names():
            # A list (not a lazy map) so it survives repeated iteration.
            yield [sympy.Symbol(n) for n in names]

    def _pad_field(self, items, width):
        """Lay out one field's items in a lookup table indexed by bit pattern.

        The table has ``2 ** width`` entries (except for zero-width one-hot
        fields, which are returned as-is); patterns that encode no value hold
        ``None``.  In one-hot mode item ``j`` sits at index ``1 << j``, in
        binary mode at index ``j``.
        """
        if self.onehot:
            if width == 0:
                return items
            table = [None] * (2 ** width)
            for j, item in enumerate(items):
                table[1 << j] = item
            return table
        return items + [None] * (2 ** width - len(items))

    def _field_values_padded(self):
        """Yield each field's values padded to a full bit-pattern lookup table."""
        for v, b in zip(self.field_values, self.field_bitwidths):
            yield self._pad_field(v, b)

    def _field_value_names_padded(self):
        """Yield each field's value names laid out exactly like _field_values_padded."""
        for v, b in zip(self.field_value_names, self.field_bitwidths):
            yield self._pad_field(v, b)

    def _indexes_to_bitv(self, indexes):
        """Return the 0/1 bit vector (aligned with ``self.symbols``) for table indexes.

        ``indexes`` holds one padded-table index per field; each is written
        in binary, MSB first, using that field's bit width.
        """
        assert len(indexes) == len(self.field_bitwidths)
        res = []
        for i, l in zip(indexes, self.field_bitwidths):
            if l == 0:
                continue
            res.extend(int(bit) for bit in bin(i)[2:].zfill(l))
        assert len(res) == len(self.symbols)
        return res

    def compute_predicate(self, predicate):
        """Minimize a predicate over all field assignments into a sympy SOP expression.

        ``predicate(values, names)`` is called once per encodable assignment
        with the list of field values and the parallel list of their names;
        a truthy result makes that assignment a minterm.  Bit patterns that
        encode no value are passed to ``sympy.SOPform`` as don't-cares.

        Returns:
            The sympy expression over ``self.symbols`` produced by SOPform.
        """
        minterms = []
        dont_cares = []
        # Names must be looked up by the same padded index as the values;
        # in one-hot mode value j lives at index 1 << j, not j.
        padded_names = list(self._field_value_names_padded())
        for option in itertools.product(*map(enumerate, self._field_values_padded())):
            indexes = [x[0] for x in option]
            values = [x[1] for x in option]
            bitv = self._indexes_to_bitv(indexes)
            # This entry in our table is undefined.
            if None in values:
                dont_cares.append(bitv)
                continue
            names = [padded_names[i][j] for i, j in enumerate(indexes)]
            if predicate(values, names):
                minterms.append(bitv)
        logging.info("Symbols: {0}".format(self.symbols))
        logging.info("Minterms: {0}".format(minterms))
        logging.info("Dont cares: {0}".format(dont_cares))
        result = sympy.SOPform(self.symbols, minterms, dont_cares)
        logging.info("Done running quine-mccluskey: {0!r}".format(result))
        return result

    def _get_values_for(self, values):
        """Expand a partial bit assignment into the matching field-value-name tuples.

        ``values`` has one entry per symbol (aligned with ``self.symbols``):
        ``True``/``False`` for a fixed bit and ``None`` for an unconstrained
        one.  Yields tuples with one entry per field: ``None`` when none of
        that field's bits are constrained (or it has no bits), otherwise the
        name of a field value whose encoding agrees with every constrained
        bit -- one tuple for every combination of such names.
        """
        possibilities = []
        rest = values
        for i, w in enumerate(self.field_bitwidths):
            term, rest = rest[:w], rest[w:]
            # Zero-width (single-valued) field, in either encoding mode.
            if w == 0:
                possibilities.append([None])
                continue
            # Which bits are constrained, and the values they must take.
            bitmask = int("".join("1" if b in (True, False) else "0" for b in term), 2)
            # No opinions.
            if bitmask == 0:
                possibilities.append([None])
                continue
            bitv = int("".join("1" if b else "0" for b in term), 2)
            names = self.field_value_names[i]
            if self.onehot:
                options = [z for j, z in enumerate(names) if (1 << j) & bitmask == bitv]
            else:
                options = [z for j, z in enumerate(names) if j & bitmask == bitv]
            possibilities.append(options)
        for res in itertools.product(*possibilities):
            yield res

    def _is_subset_of(self, a, b):
        """Return True if every row matched by ``b`` is also matched by ``a``.

        ``a`` and ``b`` are sequences of value names, with ``None`` as a
        wildcard; ``a`` covers ``b`` when each entry equals or is ``None``.
        """
        return all(x == y or x is None for (x, y) in zip(a, b))

    @staticmethod
    def _row_sort_key(row):
        """Sort key: most wildcards first, then by value with None before strings.

        Reproduces the Python 2 ordering of tuples mixing ``None`` and
        ``str`` (which Python 3 refuses to compare directly).
        """
        return (-row.count(None),
                [(x is not None, x if x is not None else "") for x in row])

    def predicate_table(self, v, excluded_rows=()):
        """Render a boolean expression over ``self.symbols`` as a LaTeX table.

        ``v`` is converted to simplified DNF and every conjunction is decoded
        into the field-value combinations it matches.  Combinations covered
        by a more general one already printed are dropped.

        The result is a ``latex_booktabs`` tabulate string whose header row
        holds the field names and whose body has one row per combination.
        ``*`` (rendered italic) marks a field left unconstrained, and ``-``
        fills every field when ``v`` is unsatisfiable.  Fields named in
        ``excluded_rows``, or unconstrained in every combination, are omitted.
        """
        v = sympy.to_dnf(v, simplify=True)
        all_conjunctions = []
        if v == sympy.false:
            ors = []
            all_conjunctions = [tuple(["-" for _ in self.field_names])]
        elif v == sympy.true:
            ors = []
            all_conjunctions = [tuple([None for _ in self.field_names])]
        elif v.func == sympy.Or:
            ors = v.args
        else:
            ors = [v]
        for conjunction in ors:
            ands = conjunction.args if conjunction.func == sympy.And else [conjunction]
            unsatisfiable = False
            values = [None] * sum(self.field_bitwidths)
            for term in ands:
                if term == sympy.true:
                    continue
                if term == sympy.false:
                    unsatisfiable = True
                    break
                negated = term.func == sympy.Not
                idx = self.symbols.index(term.args[0] if negated else term)
                assert values[idx] is None
                values[idx] = not negated
            if not unsatisfiable:
                all_conjunctions.extend(self._get_values_for(values))
        all_conjunctions = sorted(set(all_conjunctions), key=self._row_sort_key)
        # Most general rows come first, so later rows they cover are skipped.
        printed_rows = []
        for row in all_conjunctions:
            if all(not self._is_subset_of(a, row) for a in printed_rows):
                printed_rows.append(row)
        table = []
        for i, r in enumerate(self.field_names):
            if r in excluded_rows:
                continue
            vals = [row[i] for row in printed_rows]
            if all(x is None for x in vals):
                continue
            table.append([r] + ["*" if x is None else x for x in vals])
        # Transpose: field names become the header row, combinations the body.
        table = [list(col) for col in zip(*table)]
        res = tabulate.tabulate(table, tablefmt="latex_booktabs", headers="firstrow")
        res = res.replace("*", r"\textit{*}")
        return res

    def assumption_to_predicate(self, predicate, assumption):
        """Conjoin ``predicate`` with fixed field values and return it in simplified DNF.

        ``assumption`` is a list of ``(NAME, VALUE_NAME)`` pairs, where NAME
        is in ``field_names`` and VALUE_NAME is in that field's
        ``value_names``.  Each pair is encoded with the same bit patterns
        ``compute_predicate`` uses.

        Raises:
            ValueError: if a field name or value name is unknown.
        """
        clauses = []
        for fieldname, fieldvalue in assumption:
            field_index = self.field_names.index(fieldname)
            value_index = self.field_value_names[field_index].index(fieldvalue)
            width = self.field_bitwidths[field_index]
            # Slice from the flat symbol list: field_symbols skips zero-width
            # fields, so it cannot be indexed by field_index.
            start = sum(self.field_bitwidths[:field_index])
            symbols = self.symbols[start:start + width]
            # Same padded-table index as _pad_field (one-hot value j is at
            # 1 << j), written MSB first like _indexes_to_bitv.
            table_index = (1 << value_index) if self.onehot else value_index
            bits = bin(table_index)[2:].zfill(width)
            for sym, bit in zip(symbols, bits):
                clauses.append(sym if bit == "1" else sympy.Not(sym))
        assumption_expr = sympy.And(*clauses)
        logging.info(assumption_expr)
        res = sympy.And(predicate, assumption_expr)
        res = sympy.to_dnf(res, simplify=True)
        return res
def test():
    """Smoke-test CombinatoricBitfield with two small binary-encoded fields.

    Builds a 3-valued and a 7-valued field, minimizes the "sum of the two
    values is odd" predicate, then conditions it on each possible value of
    ``fieldb``.  Results are only logged; nothing is asserted.
    """
    def odd_sum(field_values, _value_names):
        return (field_values[0] + field_values[1]) % 2 == 1

    sizes = (3, 7)
    bitfield = CombinatoricBitfield(
        ["fielda", "fieldb"],
        [range(n) for n in sizes],
        [[str(i) for i in range(n)] for n in sizes],
    )
    logging.info("Starting bitfield tests...")
    parity = bitfield.compute_predicate(odd_sum)
    logging.info(repr(parity))
    for value in range(7):
        logging.info("Assuming b=" + repr(value))
        conditioned = bitfield.assumption_to_predicate(parity, [("fieldb", str(value))])
        logging.info("Full predicate=" + repr(conditioned))
    logging.info("Done testing")
# End of gist bitfield.py.