Created
June 3, 2019 20:50
-
-
Save allanlw/86ad5d3f91aac1a82c18b9c9ce08575c to your computer and use it in GitHub Desktop.
bitfield.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2019, Akamai Technologies, Inc. | |
# | |
# Permission is hereby granted, free of charge, to any person obtaining | |
# a copy of this software and associated documentation files (the | |
# "Software"), to deal in the Software without restriction, including | |
# without limitation the rights to use, copy, modify, merge, publish, | |
# distribute, sublicense, and/or sell copies of the Software, and to | |
# permit persons to whom the Software is furnished to do so, subject to | |
# the following conditions: | |
# | |
# The above copyright notice and this permission notice shall be | |
# included in all copies or substantial portions of the Software. | |
# | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | |
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | |
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | |
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | |
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |
import itertools | |
import math | |
import sympy | |
import tabulate | |
import numpy | |
import logging | |
import string | |
# This is a class for managing the encoding of a combinatoric object | |
# into a sympy space | |
class CombinatoricBitfield(object):
    """Encode a tuple of finite-valued fields as a vector of boolean symbols.

    Every field has a name, a list of possible values and a parallel list of
    printable names for those values.  Each field is mapped onto a group of
    sympy symbols, either binary encoded (the default) or one-hot encoded,
    so that predicates over the fields can be minimised and tabulated.

    Encoding conventions (shared by every method of the class):

    * The symbols of a field are stored most-significant bit first; a symbol
      called ``<field>_<k>`` stands for bit ``k`` of the field's code.
    * Binary mode: value number ``j`` has code ``j``.
    * One-hot mode: value number ``j`` has code ``1 << j``.
    * A field with a single value needs no symbols at all (width 0).

    Attributes:
        onehot: True when one-hot encoding is used.
        field_names: list of field names.
        field_values: list (per field) of lists of possible values.
        field_value_names: list (per field) of lists of value names.
        field_bitwidths: list (per field) of the number of symbols used.
        field_symbols: list (per field) of symbol lists; empty for
            zero-width fields so that it stays index-aligned with
            ``field_names``.
        symbols: flat list of all symbols, in field order.
    """

    def __init__(self, names, values, value_names, onehot = False):
        """Build the symbol layout for the given fields.

        Args:
            names: iterable of field names.
            values: iterable (per field) of iterables of possible values.
            value_names: iterable (per field) of iterables of value names.
            onehot: use one symbol per value instead of a binary code.

        Raises:
            ValueError: if the three per-field sequences differ in length.
        """
        self.onehot = onehot
        self.field_names = list(names)
        # Materialise as real lists: they are indexed and iterated many
        # times, which a lazy map() object (Python 3) does not support.
        self.field_values = [list(v) for v in values]
        self.field_value_names = [list(n) for n in value_names]
        if not (len(self.field_names) == len(self.field_values)
                == len(self.field_value_names)):
            raise ValueError(
                "names, values and value_names must describe the same number of fields")
        self.field_bitwidths = self._bitwidths()
        self.field_symbols = list(self._symbols())
        self.symbols = list(itertools.chain.from_iterable(self.field_symbols))
        assert len(self.symbols) == sum(self.field_bitwidths)

    def _ceillg2(self, x):
        """Return ceil(log2(x)) for a count ``x >= 1``.

        Integer arithmetic is used because the floating point quotient
        log(x) / log(2) overshoots for some exact powers of two, which would
        allocate one bit too many.

        Raises:
            ValueError: if ``x`` is smaller than 1 (log is undefined at 0).
        """
        if x < 1:
            raise ValueError("math domain error")
        return (int(math.ceil(x)) - 1).bit_length()

    def _bitwidths(self):
        """Return the number of symbols used by each field, as a list."""
        sizes = [len(v) for v in self.field_values]
        if self.onehot:
            # One symbol per value, except that a lone value needs none.
            return [0 if n == 1 else n for n in sizes]
        return [self._ceillg2(n) for n in sizes]

    def _symbol_names(self):
        """Yield, for every field, the list of its symbol names (MSB first).

        Zero-width fields yield an empty list so the result stays aligned
        with ``field_names``.  The names are only meant for debugging.
        """
        for name, width in zip(self.field_names, self.field_bitwidths):
            if width == 0:
                yield []
            elif width == 1:
                yield [name]
            else:
                yield [name + "_" + str(j) for j in reversed(range(width))]

    def _symbols(self):
        """Yield, for every field, the list of its sympy symbols."""
        for names in self._symbol_names():
            yield [sympy.Symbol(n) for n in names]

    def _value_position(self, index, width):
        """Map a code (index into the padded value list) to a value number.

        Only meaningful for codes that denote a real value, i.e. entries of
        the padded list that are not None.
        """
        if self.onehot and width:
            # One-hot codes are powers of two: 1 << j  ->  j.
            return index.bit_length() - 1
        return index

    def _field_values_padded(self):
        """Yield each field's values laid out by code, padded with None.

        The yielded list has one entry per code representable in the field's
        bit width; codes that do not denote a value hold None.
        """
        for vals, width in zip(self.field_values, self.field_bitwidths):
            if not self.onehot:
                yield vals + [None] * (2 ** width - len(vals))
            elif width == 0:
                yield vals
            else:
                padded = [None] * (2 ** width)
                for position, value in enumerate(vals):
                    padded[1 << position] = value
                yield padded

    def _indexes_to_bitv(self, indexes):
        """Return the flat list of 0/1 ints encoding one code per field.

        Args:
            indexes: one code per field, in field order.

        Returns:
            A list with one int per entry of ``self.symbols``, each field
            written most-significant bit first.
        """
        assert len(indexes) == len(self.field_bitwidths)
        bits = []
        for code, width in zip(indexes, self.field_bitwidths):
            if width == 0:
                continue
            bits.extend(int(ch) for ch in bin(code)[2:].zfill(width))
        assert len(bits) == len(self.symbols)
        return bits

    def compute_predicate(self, predicate):
        """Turn a Python predicate over field values into a sympy expression.

        Args:
            predicate: callable ``predicate(values, names)`` receiving one
                value and one value name per field; a truthy result marks
                the combination as satisfying.

        Returns:
            The expression produced by ``sympy.SOPform`` over
            ``self.symbols``; codes that denote no value are passed to it as
            don't-cares.
        """
        minterms = []
        dont_cares = []
        for option in itertools.product(*map(enumerate, self._field_values_padded())):
            indexes = [code for code, _ in option]
            values = [value for _, value in option]
            bitv = self._indexes_to_bitv(indexes)
            # This entry of the truth table does not denote real values.
            if None in values:
                dont_cares.append(bitv)
                continue
            # Names are stored by value number, not by code, so the code has
            # to be translated back (they differ in one-hot mode).
            names = [
                self.field_value_names[field][self._value_position(code, width)]
                for field, (code, width) in enumerate(zip(indexes, self.field_bitwidths))
            ]
            if predicate(values, names):
                minterms.append(bitv)
        logging.info("Symbols: %s", self.symbols)
        logging.info("Minterms: %s", minterms)
        logging.info("Dont cares: %s", dont_cares)
        predicate = sympy.SOPform(self.symbols, minterms, dont_cares)
        logging.info("Done running quine-mccluskey: %r", predicate)
        return predicate

    def _get_values_for(self, values):
        """Expand a partial bit assignment into tuples of value names.

        Args:
            values: one entry per symbol (same order as ``self.symbols``);
                True/False pins the bit, None leaves it unspecified.

        Yields:
            Tuples with one entry per field: None when the field is not
            constrained at all, otherwise the name of a value whose code
            agrees with every pinned bit.  One tuple is produced for every
            combination of matching values.
        """
        possibilities = []
        remaining = values
        for field, width in enumerate(self.field_bitwidths):
            term, remaining = remaining[:width], remaining[width:]
            # Single-valued fields carry no bits, hence no constraint.
            if width == 0:
                possibilities.append([None])
                continue
            bitmask = int("".join("1" if bit in (True, False) else "0" for bit in term), 2)
            # No bit of this field is pinned.
            if bitmask == 0:
                possibilities.append([None])
                continue
            bitv = int("".join("1" if bit else "0" for bit in term), 2)
            names = self.field_value_names[field]
            if self.onehot:
                options = [z for j, z in enumerate(names) if (1 << j) & bitmask == bitv]
            else:
                options = [z for j, z in enumerate(names) if j & bitmask == bitv]
            possibilities.append(options)
        for res in itertools.product(*possibilities):
            yield res

    def _is_subset_of(self, a, b):
        """Return True when row ``a`` covers row ``b``.

        Both rows hold a value name or None (wildcard) per field; ``a``
        covers ``b`` when each entry of ``a`` is a wildcard or equals the
        corresponding entry of ``b``.
        """
        return all(x is None or x == y for (x, y) in zip(a, b))

    def predicate_table(self, v, excluded_rows=()):
        """Render a predicate as a LaTeX (booktabs) table of field values.

        The predicate is converted to disjunctive normal form and every
        conjunction is expanded into the value names it allows.  The emitted
        table has one column per field and one row per distinct combination;
        unconstrained fields show ``\\textit{*}``, fields unconstrained in
        every row are dropped, and combinations already covered by a more
        general row are omitted.

        Args:
            v: sympy boolean expression over ``self.symbols``.
            excluded_rows: field names to leave out of the table.

        Returns:
            The LaTeX source of the table as a string.
        """
        v = sympy.to_dnf(v, simplify=True)
        all_conjunctions = []
        if v == sympy.false:
            ors = []
            all_conjunctions = [tuple("-" for _ in self.field_names)]
        elif v == sympy.true:
            ors = []
            all_conjunctions = [tuple(None for _ in self.field_names)]
        elif v.func == sympy.Or:
            ors = v.args
        else:
            ors = [v]

        for conjunction in ors:
            if conjunction.func == sympy.And:
                ands = conjunction.args
            else:
                ands = [conjunction]
            unsatisfiable = False
            values = [None] * sum(self.field_bitwidths)
            for term in ands:
                if term == sympy.true:
                    continue
                if term == sympy.false:
                    unsatisfiable = True
                    break
                negated = term.func == sympy.Not
                idx = self.symbols.index(term.args[0] if negated else term)
                assert values[idx] is None
                values[idx] = not negated
            if not unsatisfiable:
                all_conjunctions.extend(self._get_values_for(values))

        # Deterministic order, then most general rows (most wildcards) first.
        # The explicit key orders None before any name, because Python 3
        # refuses to compare None with str directly.
        ordered = sorted(
            set(all_conjunctions),
            key=lambda row: [(cell is not None, cell) for cell in row])
        ordered.sort(key=lambda row: -row.count(None))

        printed_rows = []
        for row in ordered:
            if not any(self._is_subset_of(seen, row) for seen in printed_rows):
                printed_rows.append(row)

        table = []
        for i, field_name in enumerate(self.field_names):
            if field_name in excluded_rows:
                continue
            vals = [row[i] for row in printed_rows]
            if all(x is None for x in vals):
                continue
            table.append([field_name] + ["*" if x is None else x for x in vals])
        # Transpose: field names become the header row.
        table = [list(column) for column in zip(*table)]
        res = tabulate.tabulate(table, tablefmt="latex_booktabs", headers="firstrow")
        res = res.replace("*", r"\textit{*}")
        return res

    def assumption_to_predicate(self, predicate, assumption):
        """Restrict a predicate to the case where some fields are fixed.

        Args:
            predicate: sympy boolean expression over ``self.symbols``.
            assumption: iterable of ``(field_name, value_name)`` pairs, each
                fixing one field to one of its named values.

        Returns:
            ``predicate AND assumption`` in simplified disjunctive normal
            form.

        Raises:
            ValueError: if a field name or value name is unknown.
        """
        clauses = []
        for fieldname, fieldvalue in assumption:
            field_index = self.field_names.index(fieldname)
            value_index = self.field_value_names[field_index].index(fieldvalue)
            symbols = self.field_symbols[field_index]
            # Same code as used when the truth table was built: 1 << j for
            # one-hot, j for binary, written most-significant bit first to
            # line up with the symbol order.
            code = (1 << value_index) if self.onehot else value_index
            bits = bin(code)[2:].zfill(len(symbols))
            for sym, bit in zip(symbols, bits):
                clauses.append(sym if bit == "1" else sympy.Not(sym))
        assumption = sympy.And(*clauses)
        logging.info(assumption)
        res = sympy.And(predicate, assumption)
        res = sympy.to_dnf(res, simplify=True)
        return res
def test():
    """Smoke-test the bitfield on two small integer fields.

    Builds a 3-valued and a 7-valued field, derives the predicate "the sum of
    both values is odd", and logs that predicate restricted to every possible
    value of the second field.
    """
    sizes = (3, 7)
    field_values = [range(size) for size in sizes]
    field_value_names = [[str(i) for i in range(size)] for size in sizes]
    bitfield = CombinatoricBitfield(["fielda", "fieldb"], field_values, field_value_names)
    logging.info("Starting bitfield tests...")

    def sum_is_odd(vals, _names):
        return (vals[0] + vals[1]) % 2 == 1

    odd_sum = bitfield.compute_predicate(sum_is_odd)
    logging.info(repr(odd_sum))
    for candidate in range(sizes[1]):
        logging.info("Assuming b=" + repr(candidate))
        restricted = bitfield.assumption_to_predicate(odd_sum, [("fieldb", str(candidate))])
        logging.info("Full predicate=" + repr(restricted))
    logging.info("Done testing")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment