Last active
June 29, 2024 04:13
-
-
Save prophile/264e48e0b375b3d2534c654293d637b4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ---------------------------------------------------------------------------
# State-machine description
# ---------------------------------------------------------------------------

# All states. The first entry is the initial state; the solver pins its code
# to all zeros.
states = ['Normal', 'Inbound', 'Unlocked']

# Transitions as (from_state, to_state, condition). The condition is a
# boolean expression string that is parsed with sympy.
transitions = [
    ('Normal', 'Inbound', 'B_enter & correct'),
    ('Inbound', 'Unlocked', 'timeout'),
    ('Unlocked', 'Normal', 'B_arrived'),
]

# Outputs as 5-tuples:
#   (name, mandatory_states, optional_states, formulation, transients_ok)
# - mandatory_states: states in which `state` must be true.
# - optional_states: states in which `state` may take either value.
# - formulation: expression string written in terms of the symbol `state`.
# - transients_ok: whether register values that are mid-transition (not yet
#   settled onto a state's full code) may be treated as don't-cares.
outputs = [
    ('clear_code', ('Inbound',), (), 'state & B_enter', True),
    ('disp_occupied', ('Inbound', 'Unlocked'), (), 'state', True),
    ('show_code', ('Normal', 'Inbound'), (), 'state | B_arrived', True),
]

# Required sets as 2-tuples (required_states, optional_states). Each entry
# claims a dedicated low-order code bit. That bit has one value in every
# required state and the opposite value in every state that is neither
# required nor optional.
required_sets = []

# When True, every possible register value must decode to some state.
UNIVERSALITY = True
#### | |
import z3 | |
import sys | |
import math | |
import tqdm | |
import pyeda.inter | |
import sympy | |
import string | |
import operator | |
import functools | |
# Disgusting pretty-print hack: drop the pretty printer's Unicode subscript
# glyphs for every ASCII letter, presumably so names like `B_enter` are
# printed literally rather than with subscript characters.
# NOTE(review): the package module is fetched through sys.modules,
# presumably because the attribute `sympy.printing.pretty` is shadowed by the
# `pretty` function that sympy.printing re-exports — confirm.
_subscript_glyphs = sys.modules['sympy.printing.pretty'].pretty_symbology.sub
for letter in string.ascii_letters:
    _subscript_glyphs.pop(letter, None)
# Range of code widths to search.
# Lower bound:
# - enough bits to give every state a distinct code, i.e. ceil(log2(#states));
# - at least one dedicated bit per required set.
# Upper bound: one bit per state plus one per distinct unordered pair of
# states joined by a transition, minus one.
min_bits = max(
    # Exact integer ceil(log2(n)) for n >= 1. The float form
    # math.log(n, 2) can overshoot for exact powers of two
    # (e.g. math.log(2**29, 2) == 29.000000000000004).
    (len(states) - 1).bit_length(),
    len(required_sets),
    # z3 bit-vectors must be at least one bit wide.
    1,
)
max_bits = max(
    # Never let the search range be empty.
    min_bits,
    len(states) + len({tuple(sorted([x, y])) for (x, y, _) in transitions}) - 1,
)
print("Bits between {} and {}".format(min_bits, max_bits))
def _build_solver(bits):
    """Build the z3 problem for a state assignment that is ``bits`` wide.

    Every state gets a ``bits``-wide code and a same-width mask. A register
    value matches a state when it equals the state's code on the bits set
    in that state's mask.

    Returns:
        ``(solver, codes, masks)``, where ``codes[i]`` / ``masks[i]`` are the
        z3 bit-vectors for ``states[i]``.
    """
    codes = [z3.BitVec(f'code_{state}', bits) for state in states]
    masks = [z3.BitVec(f'mask_{state}', bits) for state in states]
    s = z3.Solver()

    # Constraint 1: the initial state (index 0) is encoded as all zeros.
    s.add(codes[0] == 0)

    # Constraint 2: every pair of states differs in at least one bit that is
    # masked in by both, so no register value can match two states.
    for left in range(len(states) - 1):
        for right in range(left + 1, len(states)):
            s.add(
                (codes[left] ^ codes[right]) & (masks[left] & masks[right])
                != 0
            )

    # Constraint 3: every transition changes exactly one of the bits that
    # the destination state masks in.
    for from_state, to_state, _ in transitions:
        src = states.index(from_state)
        dst = states.index(to_state)
        flips = masks[dst] & (codes[src] ^ codes[dst])
        # popcount(flips) == 1, expressed as two checks:
        # - flips is non-zero;
        # - no two bits of flips are set together.
        s.add(flips != 0)
        for lo in range(bits - 1):
            for hi in range(lo + 1, bits):
                s.add(
                    (z3.Extract(lo, lo, flips) & z3.Extract(hi, hi, flips))
                    == 0
                )

    if UNIVERSALITY:
        # Constraint 4: every possible register value decodes to some state.
        # Quantifiers have been slow in recent Z3 versions. The code spaces
        # here are small, so every value is enumerated explicitly instead.
        for test_code in range(2 ** bits):
            s.add(z3.Or([
                (test_code & mask) == (code & mask)
                for code, mask in zip(codes, masks)
            ]))

    # Constraint 5: required set number `bit` owns code bit `bit`. Bit order
    # does not matter, so the sets simply take the low bits in order.
    # - Positive case: the bit is 1 in every required state and 0 in every
    #   state that is neither required nor optional.
    # - Negative case: the opposite polarity.
    for bit, (required_states, optional_states) in enumerate(required_sets):
        required_indices = [states.index(name) for name in required_states]
        optional_indices = [states.index(name) for name in optional_states]
        other_indices = [
            ix
            for ix in range(len(states))
            if ix not in optional_indices
            if ix not in required_indices
        ]
        bit_mask = 1 << bit
        positive_case = functools.reduce(
            operator.and_,
            [codes[ix] for ix in required_indices]
            + [~codes[ix] for ix in other_indices],
            bit_mask,
        ) == bit_mask
        negative_case = functools.reduce(
            operator.and_,
            [~codes[ix] for ix in required_indices]
            + [codes[ix] for ix in other_indices],
            bit_mask,
        ) == bit_mask
        s.add(z3.Or(positive_case, negative_case))

    return s, codes, masks


# Try increasing widths and stop at the first satisfiable one. The
# module-level `bits`, `s`, `codes` and `masks` are used by the rest of the
# script.
for bits in tqdm.trange(min_bits, max_bits + 1):
    s, codes, masks = _build_solver(bits)
    if s.check() == z3.sat:
        break
else:
    # Every width in the range was unsatisfiable (or the range was empty).
    print('No solution found.', file=sys.stderr)
    sys.exit(1)
# Decode the satisfying model. For every bit position, each state is filed
# in two ways:
# - high / low / don't-care: what the bit is driven to in that state;
# - masked / unmasked: whether the state's decoder looks at the bit.
# Each state's codeword is also printed, most significant bit first.
model = s.model()
high_by_bit = [[] for _ in range(bits)]
low_by_bit = [[] for _ in range(bits)]
dc_by_bit = [[] for _ in range(bits)]
unmasked_by_bit = [[] for _ in range(bits)]
masked_by_bit = [[] for _ in range(bits)]

# (mask bit, code bit) -> codeword letter. Unmasked bits print as 'z'/'h'
# (pulled to 0/1), masked bits as '0'/'1'. True don't-cares print as 'x'.
codeword_letter = {
    (0, 0): 'z',
    (0, 1): 'h',
    (1, 0): '0',
    (1, 1): '1',
}

for ix, state in enumerate(states):
    code = model[codes[ix]].as_long()
    mask = model[masks[ix]].as_long()
    # Masks of every state this state can transition to.
    successor_masks = [
        model[masks[states.index(dst)]].as_long()
        for src, dst, _ in transitions
        if src == state
    ]
    letters = []
    for n in range(bits):
        mask_bit = (mask >> n) & 1
        code_bit = (code >> n) & 1
        # The bit is a true don't-care only when all of these hold:
        # - it is not a dedicated required-set bit;
        # - it is unmasked here and in every successor state;
        # - this is not the initial state, which is always pulled to a
        #   predictable value.
        do_not_care = (
            ix != 0
            and n >= len(required_sets)
            and not mask_bit
            and not any((succ_mask >> n) & 1 for succ_mask in successor_masks)
        )
        if do_not_care:
            letters.append('x')
            dc_by_bit[n].append(state)
        else:
            letters.append(codeword_letter[mask_bit, code_bit])
            (high_by_bit if code_bit else low_by_bit)[n].append(state)
        (masked_by_bit if mask_bit else unmasked_by_bit)[n].append(state)
    codeword = ''.join(reversed(letters))
    print(f"State: {state}")
    print(f" Code: {codeword}")
    print()
# One sympy symbol per register bit, named $0, $1, ...; bit n of the state
# register is state_variables[n].
state_variables = sympy.symbols(['$' + str(bit) for bit in range(bits)])
def _pyeda_ast2sympy(eda_vars, ast):
    """Convert a PyEDA expression AST (from ``Expression.to_ast()``) to sympy.

    ``ast`` is a nested tuple whose first element names the node type. The
    node types handled are 'lit', 'const', 'not', 'and', 'or', 'xor', 'eq',
    'impl' and 'ite'.

    Literals map onto ``state_variables``:
    - a positive literal ``k`` becomes ``state_variables[k - 1]``;
    - a negative literal becomes that symbol's complement.
    A 'const' node is returned as its raw value.

    NOTE(review): the literal mapping assumes the PyEDA variable ids are
    exactly 1..bits, i.e. that the ``ttvars('S', bits)`` variables are the
    first PyEDA variables created in the process. ``eda_vars`` is currently
    unused. Confirm before creating other PyEDA variables.

    Raises:
        ValueError: for an AST node type this converter does not handle.
    """
    op = ast[0]
    if op == 'lit':
        literal = ast[1]
        if literal > 0:
            return state_variables[literal - 1]
        return ~state_variables[-literal - 1]
    if op == 'const':
        return ast[1]
    if op == 'not':
        return ~_pyeda_ast2sympy(eda_vars, ast[1])
    if op not in ('and', 'or', 'xor', 'impl', 'ite', 'eq'):
        # Previously unknown nodes fell through and returned None, which
        # surfaced later as a confusing sympy error.
        raise ValueError(f"Unsupported PyEDA AST node type: {op!r}")

    children = [_pyeda_ast2sympy(eda_vars, child) for child in ast[1:]]
    if op == 'and':
        return sympy.And(*children)
    if op == 'or':
        return sympy.Or(*children)
    if op == 'xor':
        return sympy.Xor(*children)
    if op == 'impl':
        antecedent, consequent = children
        return ~antecedent | consequent
    if op == 'ite':
        condition, if_true, if_false = children
        return sympy.ITE(condition, if_true, if_false)
    # op == 'eq'
    if len(children) == 2:
        left, right = children
        return ~(left ^ right)
    # n-ary equality: all operands share one truth value.
    return sympy.Equivalent(*children)
def _pyeda2sympy(eda_vars, expr):
    """Translate the PyEDA expression ``expr`` into an equivalent sympy expression."""
    return _pyeda_ast2sympy(eda_vars, expr.to_ast())
@functools.cache
def _match_state_canonical(mandatory, optional, transients_ok):
    """Synthesise a minimal test for "the register is in a mandatory state".

    A truth table over every ``bits``-wide register value is built from the
    solved encoding, then minimised with PyEDA's Espresso implementation.

    Each register value is decoded to the first state (in ``states`` order)
    whose masked bits it matches. It is "stable" when it also equals that
    state's code on every bit that is not a don't-care for the state.

    Args:
        mandatory: tuple of state names in which the result must be true.
        optional: tuple of state names in which either value is acceptable.
        transients_ok: when true, unstable register values (and values that
            match no state) become don't-cares instead of hard 0/1 entries.

    Returns:
        A sympy boolean expression over ``state_variables``, or the raw
        Espresso constant when the function minimises to a constant.

    Cached with ``functools.cache``: the result also depends on
    module-level solver state (``model``, ``codes``, ``masks``, ``bits``,
    ``dc_by_bit``), which is assumed fixed once decoding has run.
    """
    all_bits = 2 ** bits - 1

    # Decode every state once instead of once per register value:
    # (state index, code, mask, stability mask) in ``states`` order. The
    # stability mask clears the bits on which the state is a don't-care.
    decoded = []
    for ix, state in enumerate(states):
        stability_mask = all_bits
        for bit_ix, dc_states in enumerate(dc_by_bit):
            if state in dc_states:
                stability_mask &= ~(1 << bit_ix)
        decoded.append((
            ix,
            model[codes[ix]].as_long(),
            model[masks[ix]].as_long(),
            stability_mask,
        ))

    mandatory_indices = {
        ix for ix, state in enumerate(states) if state in mandatory
    }
    optional_indices = {
        ix for ix, state in enumerate(states) if state in optional
    }

    truth_table = []
    for configuration in range(2 ** bits):
        for ix, code, mask, stability_mask in decoded:
            if (configuration & mask) != (code & mask):
                continue
            is_stable = (
                (configuration & stability_mask) == (code & stability_mask)
            )
            if ix in mandatory_indices:
                # Mandatory: drive the output, unless the value is a
                # tolerated transient.
                entry = '1' if is_stable or not transients_ok else 'x'
            elif ix in optional_indices:
                entry = 'x'
            else:
                # Excluded state: unstable values may be ignored when
                # transients are OK; stable ones must never assert.
                entry = 'x' if transients_ok and not is_stable else '0'
            break
        else:
            # No state matched this value, so the code is not universal.
            print(f"Cannot match config {configuration:b}")
            assert not UNIVERSALITY
            entry = 'x' if transients_ok else '0'
        truth_table.append(entry)

    assert len(truth_table) == 2 ** bits
    eda_vars = pyeda.inter.ttvars('S', bits)
    eda_tt = pyeda.inter.truthtable(eda_vars, ''.join(truth_table))
    optimised_form, = pyeda.inter.espresso_tts(eda_tt)
    return _pyeda2sympy(eda_vars, optimised_form)
def match_state(mandatory, optional=(), transients_ok=False):
    """Return a boolean expression that is true while in a ``mandatory`` state.

    ``optional`` states are don't-cares. ``transients_ok`` lets
    mid-transition register values be don't-cares.

    The arguments are normalised (deduplicated, sorted, coerced to bool) so
    that equivalent calls share one cached synthesis. A constant result is
    returned as a plain ``bool``.
    """
    def _canonical(names):
        return tuple(sorted(set(names)))

    expression = _match_state_canonical(
        _canonical(mandatory),
        _canonical(optional),
        bool(transients_ok),
    )
    return bool(expression) if isinstance(expression, int) else expression
# Synthesise set/clear logic for each register bit, most significant first.
# A bit is driven in two situations:
#   1. The current state does not mask the bit in, so the bit is pulled
#      towards that state's code value (unless it is a true don't-care).
#   2. The bit is masked in by both ends of a transition whose codes differ
#      on it, so firing the transition sets or clears it.

# Transition data does not depend on the bit, so decode and parse it once:
# (from_state, to_code, transition_bits, parsed condition).
decoded_transitions = []
for from_state, to_state, condition in transitions:
    from_ix = states.index(from_state)
    to_ix = states.index(to_state)
    from_code = model[codes[from_ix]].as_long()
    from_mask = model[masks[from_ix]].as_long()
    to_code = model[codes[to_ix]].as_long()
    to_mask = model[masks[to_ix]].as_long()
    decoded_transitions.append((
        from_state,
        to_code,
        from_mask & to_mask & (from_code ^ to_code),
        sympy.parse_expr(condition),
    ))

for n in reversed(range(bits)):
    print(f"Bit {n}")
    print(f" High: {', '.join(high_by_bit[n])}")
    print(f" Low: {', '.join(low_by_bit[n])}")
    if dc_by_bit[n]:
        print(f" Don't care: {', '.join(dc_by_bit[n])}")
    if not masked_by_bit[n]:
        print(" NB: Not masked")

    set_conditions = []
    clear_conditions = []

    # Case 1: pull unmasked, non-don't-care bits towards the state's code.
    pull_set_states = []
    pull_clear_states = []
    for ix, state in enumerate(states):
        code = model[codes[ix]].as_long()
        mask = model[masks[ix]].as_long()
        if (mask >> n) & 1 or state in dc_by_bit[n]:
            continue
        if (code >> n) & 1:
            pull_set_states.append(state)
        else:
            pull_clear_states.append(state)
    set_conditions.append(match_state(pull_set_states, optional=high_by_bit[n]))
    clear_conditions.append(match_state(pull_clear_states, optional=low_by_bit[n]))

    # Case 2: transitions that flip this bit.
    for from_state, to_code, transition_bits, condition_parsed in decoded_transitions:
        if not ((transition_bits >> n) & 1):
            continue
        if (to_code >> n) & 1:
            set_conditions.append(sympy.And(
                condition_parsed,
                match_state([from_state], optional=high_by_bit[n] + unmasked_by_bit[n]),
            ))
        else:
            clear_conditions.append(sympy.And(
                condition_parsed,
                match_state([from_state], optional=low_by_bit[n] + unmasked_by_bit[n]),
            ))

    # Set logic is simplified with this bit assumed 0, and clear logic with
    # it assumed 1 (presumably because setting a set bit, or clearing a
    # clear one, is a no-op).
    command_set = sympy.simplify_logic(
        sympy.Or(*set_conditions).subs(state_variables[n], False),
        form='dnf',
        force=True,
    )
    command_clear = sympy.simplify_logic(
        sympy.Or(*clear_conditions).subs(state_variables[n], True),
        form='dnf',
        force=True,
    )
    print(" Set: ", end='')
    sympy.pprint(command_set, use_unicode=True)
    print(" Clear: ", end='')
    sympy.pprint(command_clear, use_unicode=True)
    all_symbols = command_set.free_symbols | command_clear.free_symbols
    print(f" Control logic is {len(all_symbols)} bits")
    print()
# Synthesise and print the combinational logic for each declared output.
# Each output's formulation is written in terms of the symbol `state`, which
# is replaced by the synthesised state test.
for name, mandatory, optional, formulation, transients_ok in outputs:
    state_test = match_state(mandatory, optional, transients_ok)
    expression = sympy.parse_expr(formulation).subs('state', state_test)
    print(f"Output {name}")
    print(" Expression: ", end='')
    sympy.pprint(expression, use_unicode=True)
    print(f" Output logic is {len(expression.free_symbols)} bits")
    print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment