# GitHub Gist prophile/264e48e0b375b3d2534c654293d637b4 by @prophile
# (last active June 29, 2024 04:13).
#
# The page navigation text captured with this gist ("Skip to content",
# "Instantly share code, notes, and snippets.", "Show Gist options",
# "Save ... to your computer and use it in GitHub Desktop.") is kept here as a
# comment so the file parses as Python.
# ---- Machine description -------------------------------------------------
# State names.  states[0] is the initial state: the solver pins its code to 0.
states = ['Normal', 'Inbound', 'Unlocked']

# Transitions as (from_state, to_state, condition).  Each condition is boolean
# expression text over input signals, parsed later with sympy.parse_expr.
transitions = [
    ('Normal', 'Inbound', 'B_enter & correct'),
    ('Inbound', 'Unlocked', 'timeout'),
    ('Unlocked', 'Normal', 'B_arrived'),
]

# Outputs as (name, mandatory_states, optional_states, formulation,
# transients_ok):
#   mandatory_states -- states in which the state-match term must be true;
#   optional_states  -- states in which its value does not matter;
#   formulation      -- expression text in which the symbol `state` is
#                       replaced by the derived state-match term;
#   transients_ok    -- if True, configurations that are not yet stable may
#                       take either value.
outputs = [
    ('clear_code', ('Inbound',), (), 'state & B_enter', True),
    ('disp_occupied', ('Inbound', 'Unlocked'), (), 'state', True),
    ('show_code', ('Normal', 'Inbound'), (), 'state | B_arrived', True),
]

# Required sets as (required_states, optional_states).  Set i is given code
# bit i, which must take one value in every required state and the opposite
# value in every state that is neither required nor optional.
required_sets = []

# When True, every possible bit pattern must match the masked code of some
# state, so no configuration is left unassigned.
UNIVERSALITY = True
####
import z3
import sys
import math
import tqdm
import pyeda.inter
import sympy
import string
import operator
import functools
# Pretty-print hack: remove every ASCII letter from sympy's unicode subscript
# table (pretty_symbology.sub).
# NOTE(review): presumably this makes names such as B_enter print verbatim
# rather than with partially-subscripted letters -- confirm against sympy.
# The subpackage is fetched through sys.modules, presumably because the
# attribute `sympy.printing.pretty` is shadowed by the pretty() function.
import sympy.printing.pretty
_subscript_table = sys.modules['sympy.printing.pretty'].pretty_symbology.sub
for x in string.ascii_letters:
    _subscript_table.pop(x, None)
# Range of code widths to search.
# Lower bound, the largest of:
#   * 1, because z3 bit-vectors cannot be zero-width (a single-state machine
#     would otherwise ask for a 0-bit code);
#   * ceil(log2(len(states))), computed exactly as (n - 1).bit_length() rather
#     than through a floating-point logarithm;
#   * one bit per required set, since each set gets a dedicated bit.
min_bits = max(
    1,
    (len(states) - 1).bit_length(),
    len(required_sets),
)
# Upper bound: one bit per state plus one per distinct unordered pair of
# states joined by a transition, minus one (a heuristic ceiling).  It is never
# allowed below min_bits, so at least one width is always tried.
max_bits = max(
    min_bits,
    len(states) + len({tuple(sorted([x, y])) for (x, y, _) in transitions}) - 1,
)
print("Bits between {} and {}".format(min_bits, max_bits))
def _build_solver(bits):
    """Return ``(codes, masks, solver)`` encoding the state assignment at *bits* width.

    ``codes[i]`` and ``masks[i]`` are z3 bit-vectors for ``states[i]``: its
    code, and which of its bits are relevant (a 1 bit is "masked in").  Any
    model of ``solver`` is an encoding satisfying constraints 1-5 below.
    """
    codes = [z3.BitVec(f'code_{state}', bits) for state in states]
    masks = [z3.BitVec(f'mask_{state}', bits) for state in states]
    solver = z3.Solver()

    # Constraint 1: state 0 must have code 0.
    solver.add(codes[0] == 0)

    # Constraint 2: every pair of states must differ in at least one bit that
    # is masked in for both of them.
    for left in range(len(states) - 1):
        for right in range(left + 1, len(states)):
            solver.add(
                ((codes[left] ^ codes[right]) & masks[left] & masks[right]) != 0
            )

    # Constraint 3: every transition must change exactly one bit that is
    # masked in for the destination state.
    for from_state, to_state, _ in transitions:
        from_index = states.index(from_state)
        to_index = states.index(to_state)
        flips = masks[to_index] & (codes[from_index] ^ codes[to_index])
        # Population count == 1  <=>  the vector is non-zero and clearing its
        # lowest set bit (x & (x - 1)) leaves nothing behind.
        solver.add(flips != 0)
        solver.add((flips & (flips - 1)) == 0)

    if UNIVERSALITY:
        # Constraint 4: every possible code must match some state.
        # Quantifiers have been slow in recent versions of Z3; the codebooks
        # are small, so enumerate every possible code explicitly instead.
        for test_code in range(2 ** bits):
            solver.add(z3.Or([
                (test_code & mask) == (code & mask)
                for code, mask in zip(codes, masks)
            ]))

    # Constraint 5: give each required set its own bit.  Bit order is not
    # significant, so set i simply uses bit i.  That bit must be 1 in every
    # required state and 0 in every state that is neither required nor
    # optional, or the reverse polarity.
    for bit, (required_states, optional_states) in enumerate(required_sets):
        required_indices = [states.index(name) for name in required_states]
        optional_indices = [states.index(name) for name in optional_states]
        excluded_indices = [
            ix
            for ix in range(len(states))
            if ix not in optional_indices and ix not in required_indices
        ]
        bit_mask = 1 << bit
        # Bit set in every required code and clear in every excluded code...
        positive_case = functools.reduce(
            operator.and_,
            [codes[ix] for ix in required_indices]
            + [~codes[ix] for ix in excluded_indices],
            bit_mask,
        ) == bit_mask
        # ...or clear in every required code and set in every excluded code.
        negative_case = functools.reduce(
            operator.and_,
            [~codes[ix] for ix in required_indices]
            + [codes[ix] for ix in excluded_indices],
            bit_mask,
        ) == bit_mask
        solver.add(z3.Or(positive_case, negative_case))

    return codes, masks, solver


# Search widths from smallest to largest and stop at the first satisfiable
# one.  `bits`, `codes`, `masks` and `s` stay bound for the code below.
for bits in tqdm.trange(min_bits, max_bits + 1):
    codes, masks, s = _build_solver(bits)
    result = s.check()
    if result == z3.sat:
        break
else:
    print('No solution found.')
    sys.exit(1)
model = s.model()

# Solved code and mask of every state as Python ints, index-aligned with
# `states`.  model_completion=True asks z3 to supply a value for any variable
# the model leaves unconstrained, where model[var] would return None.
_state_codes = [model.eval(var, model_completion=True).as_long() for var in codes]
_state_masks = [model.eval(var, model_completion=True).as_long() for var in masks]

# Per-bit classification of the states, filled in below.
high_by_bit = [[] for _ in range(bits)]
low_by_bit = [[] for _ in range(bits)]
dc_by_bit = [[] for _ in range(bits)]
unmasked_by_bit = [[] for _ in range(bits)]
masked_by_bit = [[] for _ in range(bits)]

# Display letter for one bit, keyed by (mask bit, code bit): masked-out bits
# show as z/h (low/high), masked-in bits as 0/1.
_CODE_LETTERS = {
    (0, 0): 'z',
    (0, 1): 'h',
    (1, 0): '0',
    (1, 1): '1',
}

for ix, state in enumerate(states):
    code = _state_codes[ix]
    mask = _state_masks[ix]
    # Masks of the states this state can transition to directly.
    successor_masks = [
        _state_masks[states.index(to_state)]
        for from_state, to_state, _ in transitions
        if from_state == state
    ]
    code_letters = [
        _CODE_LETTERS[(mask >> n) & 1, (code >> n) & 1]
        for n in range(bits)
    ]
    for n in range(bits):
        is_masked = (mask >> n) & 1
        # A z/h bit becomes a true don't-care ('x') when it is masked out here
        # and in every state this one transitions to.  Required-set bits (the
        # low len(required_sets) bits) are never don't-care, and the initial
        # state (index 0) always keeps a definite value so it pulls predictably.
        do_not_care = (
            ix != 0
            and n >= len(required_sets)
            and not is_masked
            and not any((succ_mask >> n) & 1 for succ_mask in successor_masks)
        )
        if do_not_care:
            code_letters[n] = 'x'
            dc_by_bit[n].append(state)
        elif (code >> n) & 1:
            high_by_bit[n].append(state)
        else:
            low_by_bit[n].append(state)
        if is_masked:
            masked_by_bit[n].append(state)
        else:
            unmasked_by_bit[n].append(state)
    # Letters were built least-significant bit first; display MSB first.
    codeword = ''.join(reversed(code_letters))
    print(f"State: {state}")
    print(f" Code: {codeword}")
    print()
# One sympy symbol per state bit, named $0, $1, ...; element i stands for bit i.
state_variables = sympy.symbols(['$' + str(bit) for bit in range(bits)])
def _pyeda_ast2sympy(eda_vars, ast):
    """Convert a PyEDA expression AST (from ``to_ast()``) into a sympy expression.

    Literal nodes ``('lit', k)`` map to ``state_variables[k - 1]`` for k > 0
    and to its negation ``~state_variables[-k - 1]`` for k < 0.

    NOTE(review): this assumes the PyEDA variable with id ``k`` is state bit
    ``k - 1``, which presumably holds only while the ``S[...]`` variables are
    the first PyEDA variables created in the process -- TODO confirm if other
    PyEDA variables are ever introduced.

    Args:
        eda_vars: the PyEDA input variables; passed through the recursion but
            not otherwise used.
        ast: nested tuple whose first element names the node type.

    Returns:
        A sympy boolean expression, or the raw constant value for a
        ``'const'`` node.

    Raises:
        ValueError: for a node type this converter does not handle.
    """
    kind = ast[0]
    if kind == 'lit':
        literal = ast[1]
        if literal > 0:
            return state_variables[literal - 1]
        return ~state_variables[-literal - 1]
    if kind == 'const':
        return ast[1]

    children = [_pyeda_ast2sympy(eda_vars, child) for child in ast[1:]]
    if kind == 'and':
        return sympy.And(*children)
    if kind == 'or':
        return sympy.Or(*children)
    if kind == 'xor':
        return sympy.Xor(*children)
    if kind == 'not':
        return ~children[0]
    if kind == 'impl':
        return ~children[0] | children[1]
    if kind == 'ite':
        return sympy.ITE(*children)
    if kind == 'eq':
        if len(children) == 2:
            return ~(children[0] ^ children[1])
        # N-ary equality: true exactly when every operand has the same value.
        return sympy.Equivalent(*children)
    # Previously an unknown node silently became None and failed later with
    # an unrelated error; fail here with a clear message instead.
    raise ValueError(f"Unsupported PyEDA AST node type: {kind!r}")
def _pyeda2sympy(eda_vars, expr):
    """Translate the PyEDA expression *expr* into sympy via its AST form."""
    return _pyeda_ast2sympy(eda_vars, expr.to_ast())
@functools.cache
def _match_state_canonical(mandatory, optional, transients_ok):
    """Return a minimised expression over the state bits that matches *mandatory*.

    A truth table over all ``2 ** bits`` configurations is built from the
    solved model and minimised with the Espresso algorithm (via PyEDA).  Each
    configuration is attributed to the state whose masked code it matches,
    and its output is:

    * mandatory state: ``1`` when the configuration is stable (equal to the
      code on every bit that is not a don't-care for that state), or always
      ``1`` when ``transients_ok`` is false; otherwise don't-care;
    * optional state: don't-care;
    * any other state: ``0``, except don't-care for unstable configurations
      when ``transients_ok`` is true;
    * no matching state (only possible without ``UNIVERSALITY``):
      don't-care if ``transients_ok`` else ``0``.

    Args:
        mandatory: tuple of state names that must drive the output.
        optional: tuple of state names whose output does not matter.
        transients_ok: whether unstable configurations may take either value.
        All arguments must be hashable because results are cached;
        ``match_state`` canonicalises them.

    Returns:
        A sympy boolean expression over ``state_variables``, or a plain int
        when the minimised function is a constant.
    """
    all_bits = 2 ** bits - 1

    # Resolve each state's code, mask, stability mask and role once, rather
    # than once per configuration.  model_completion=True supplies a value for
    # any variable the model left unconstrained.
    candidates = []
    for ix, state in enumerate(states):
        code = model.eval(codes[ix], model_completion=True).as_long()
        mask = model.eval(masks[ix], model_completion=True).as_long()
        # Bits that are don't-care for this state do not count for stability.
        stability_mask = all_bits
        for bit_ix, dc_states in enumerate(dc_by_bit):
            if state in dc_states:
                stability_mask &= ~(1 << bit_ix)
        if state in mandatory:
            role = 'mandatory'
        elif state in optional:
            role = 'optional'
        else:
            role = 'excluded'
        candidates.append((code, mask, stability_mask, role))

    truth_table = []
    for configuration in range(2 ** bits):
        for code, mask, stability_mask, role in candidates:
            if (configuration & mask) != (code & mask):
                continue
            is_stable = (configuration & stability_mask) == (code & stability_mask)
            if role == 'mandatory':
                # Drive the output when stable, or throughout when transients
                # are not tolerated.
                truth_table.append('1' if is_stable or not transients_ok else 'x')
            elif role == 'optional':
                truth_table.append('x')
            else:
                # Excluded state: stable configurations must never drive the
                # output; unstable ones may be don't-care if transients are OK.
                truth_table.append('x' if transients_ok and not is_stable else '0')
            break
        else:
            # No state matched this configuration - it's an error state and the code is not universal
            print(f"Cannot match config {configuration:b}")
            assert not UNIVERSALITY
            truth_table.append('x' if transients_ok else '0')
    assert len(truth_table) == 2 ** bits

    eda_vars = pyeda.inter.ttvars('S', bits)
    eda_tt = pyeda.inter.truthtable(eda_vars, ''.join(truth_table))
    optimised_form, = pyeda.inter.espresso_tts(eda_tt)
    return _pyeda2sympy(eda_vars, optimised_form)
def match_state(mandatory, optional=(), transients_ok=False):
    """Return an expression over the state bits that is true in *mandatory* states.

    Front end for the cached ``_match_state_canonical``: both state lists are
    de-duplicated and sorted into tuples (hashable, and equivalent requests
    share one cache entry), and ``transients_ok`` is coerced to bool.

    Args:
        mandatory: iterable of state names that must drive the output.
        optional: iterable of state names whose output does not matter.
        transients_ok: whether unstable configurations may take either value.

    Returns:
        A sympy expression, or a bool when the minimised function is constant.
    """
    def canonical(names):
        return tuple(sorted(set(names)))

    matched = _match_state_canonical(
        canonical(mandatory),
        canonical(optional),
        bool(transients_ok),
    )
    return bool(matched) if isinstance(matched, int) else matched
# Latch-output symbol; nothing below references it at present.
Q = sympy.symbols('Q')

# Solved code and mask of every state, keyed by state name and looked up once
# rather than on every bit.  model_completion=True makes z3 supply a value for
# any variable the model left unconstrained (model[var] would be None).
_code_of = {
    state: model.eval(codes[ix], model_completion=True).as_long()
    for ix, state in enumerate(states)
}
_mask_of = {
    state: model.eval(masks[ix], model_completion=True).as_long()
    for ix, state in enumerate(states)
}

for n in reversed(range(bits)):
    print(f"Bit {n}")
    print(f" High: {', '.join(high_by_bit[n])}")
    print(f" Low: {', '.join(low_by_bit[n])}")
    if dc_by_bit[n]:
        print(f" Don't care: {', '.join(dc_by_bit[n])}")
    if not masked_by_bit[n]:
        print(" NB: Not masked")

    # Set/clear conditions for this bit come from two sources:
    # 1. States where the bit is masked out (irrelevant) but not a true
    #    don't-care: the bit is pulled towards that state's code.
    # 2. Transitions that flip this (relevant) bit: the transition condition,
    #    qualified by being in the source state, sets or clears it.
    pull_set_states = []
    pull_clear_states = []
    for state in states:
        if ((_mask_of[state] >> n) & 1) or state in dc_by_bit[n]:
            continue
        if (_code_of[state] >> n) & 1:
            pull_set_states.append(state)
        else:
            pull_clear_states.append(state)
    set_conditions = [match_state(pull_set_states, optional=high_by_bit[n])]
    clear_conditions = [match_state(pull_clear_states, optional=low_by_bit[n])]

    for from_state, to_state, condition in transitions:
        from_code = _code_of[from_state]
        to_code = _code_of[to_state]
        transition_bits = _mask_of[from_state] & _mask_of[to_state] & (from_code ^ to_code)
        if not ((transition_bits >> n) & 1):
            continue
        condition_parsed = sympy.parse_expr(condition)
        if (to_code >> n) & 1:
            set_conditions.append(sympy.And(
                condition_parsed,
                match_state([from_state], optional=high_by_bit[n] + unmasked_by_bit[n]),
            ))
        else:
            clear_conditions.append(sympy.And(
                condition_parsed,
                match_state([from_state], optional=low_by_bit[n] + unmasked_by_bit[n]),
            ))

    # A set only matters while the bit is 0 and a clear only while it is 1,
    # so fix the bit to that value before simplifying.
    command_set = sympy.simplify_logic(
        sympy.Or(*set_conditions).subs(state_variables[n], False),
        form='dnf',
        force=True,
    )
    command_clear = sympy.simplify_logic(
        sympy.Or(*clear_conditions).subs(state_variables[n], True),
        form='dnf',
        force=True,
    )
    print(" Set: ", end='')
    sympy.pprint(command_set, use_unicode=True)
    print(" Clear: ", end='')
    sympy.pprint(command_clear, use_unicode=True)
    all_symbols = command_set.free_symbols | command_clear.free_symbols
    print(f" Control logic is {len(all_symbols)} bits")
    print()
# Derive and print the logic for each output: parse its formulation with sympy
# and replace the `state` symbol by the matching state expression.
for output_spec in outputs:
    name, mandatory, optional, formula_text, transients_ok = output_spec
    state_match = match_state(mandatory, optional, transients_ok)
    expression = sympy.parse_expr(formula_text).subs('state', state_match)
    print(f"Output {name}")
    print(" Expression: ", end='')
    sympy.pprint(expression, use_unicode=True)
    print(f" Output logic is {len(expression.free_symbols)} bits")
    print()
# GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment"