Skip to content

Instantly share code, notes, and snippets.

@nunoplopes
Last active January 2, 2023 18:56
Show Gist options
  • Select an option

  • Save nunoplopes/5789ee1c5766d04914802c040503dc0a to your computer and use it in GitHub Desktop.

Select an option

Save nunoplopes/5789ee1c5766d04914802c040503dc0a to your computer and use it in GitHub Desktop.
Prototype of Dynamo guard simplification [don't use]
def mk_eq(lhs, rhs):
    """Build ``sympy.Eq(lhs, rhs)``, first moving integer offsets to the rhs.

    While ``lhs`` is an ``Add`` whose first argument is an integer constant
    ``c`` and ``rhs`` is an integer, ``c`` is subtracted from both sides. For
    example, ``s0 + 2 == 5`` is built as ``s0 == 3``. (sympy keeps the numeric
    term first in ``Add.args``.)

    Returns whatever ``sympy.Eq`` returns for the normalized sides.
    """

    def _has_int_offset(expr):
        # True when expr looks like `c + rest` with an integer constant c.
        return isinstance(expr, sympy.Add) and isinstance(expr.args[0], sympy.Integer)

    while _has_int_offset(lhs) and isinstance(rhs, sympy.Integer):
        offset = lhs.args[0]
        lhs, rhs = lhs - offset, rhs - offset
    return sympy.Eq(lhs, rhs)
def simplify_eq(eq):
    """Rewrite an equality guard into a cheaper relation, if a rule applies.

    The only rule so far is ``FloorDiv(a, b) == 0`` -> ``a < b``, passed
    through ``.simplify()``.

    Returns:
        The simplified relation, or None when no rule matched.
    """
    # NOTE(review): floor(a / b) == 0 is equivalent to 0 <= a < b only when
    # b > 0; this rewrite presumably relies on a being a non-negative size and
    # b being positive - confirm against the callers.
    head = eq.lhs
    if not (isinstance(head, FloorDiv) and eq.rhs == 0):
        return None
    numerator = head.args[0]
    denominator = head.args[1]
    return sympy.Lt(numerator, denominator).simplify()
def track_symint(source, val):
    ...


# NOTE(review): everything below is a fragment of an enclosing method that is
# not shown here. `self`, `exprs`, `symbol_to_source`, `source_ref` and `log`
# come from that context; `collections`, `sympy` and `ShapeGuardPrinter` come
# from the file's imports. Confirm against the full source.
#
# The recorded guards are split into two groups:
#   * equalities, grouped into classes of expressions the guards force equal;
#   * every other relation, bucketed by left-hand side.
# Each class is emitted as equalities against one target (its integer
# constant, or a canonical member if it has none). Each remaining relation is
# emitted unless the already-emitted equalities reduce it to `sympy.true`.
# Printed guards are appended to `exprs`.


def _raise_unsat():
    """Log and raise: the recorded guards cannot all hold at once."""
    log.error("Guard is unsat\n")
    raise Exception("Guard is unsat")


def _emit_guard(guard):
    """Print `guard` with ShapeGuardPrinter and append the text to `exprs`."""
    try:
        exprs.append(ShapeGuardPrinter(symbol_to_source, source_ref).doprint(guard))
    except Exception:
        log.warning("Failing guard\n")
        raise


# Almost a union-find, but not quite: all members of a class share one set
# object, and merging two classes repoints every member of the rhs class.
equivalences = {}  # expr -> set of exprs the guards force equal to it
comparisons = collections.defaultdict(list)  # lhs -> non-equality relations
for g, _tb in self.guards:
    # Statically decidable guards need no runtime check.
    if self._maybe_evaluate_static(g) is not None:
        continue
    g = self.simplify(g)
    if isinstance(g, sympy.Eq):
        lhs_cls = equivalences.get(g.lhs)
        rhs_cls = equivalences.get(g.rhs)
        if lhs_cls is not None and rhs_cls is not None:
            if lhs_cls is not rhs_cls:
                lhs_cls |= rhs_cls
                for e in rhs_cls:
                    equivalences[e] = lhs_cls
        elif lhs_cls is not None:
            lhs_cls.add(g.rhs)
            equivalences[g.rhs] = lhs_cls
        elif rhs_cls is not None:
            rhs_cls.add(g.lhs)
            equivalences[g.lhs] = rhs_cls
        else:
            new_cls = {g.lhs, g.rhs}
            equivalences[g.lhs] = new_cls
            equivalences[g.rhs] = new_cls
    else:
        assert isinstance(g, sympy.Rel), g
        comparisons[g.lhs].append(g)

emitted_eqs = []  # (lhs, rhs) pairs already enforced by an emitted guard


def _substitute_emitted(expr):
    """Rewrite `expr` with every equality recorded so far (lhs -> rhs)."""
    for lhs, rhs in emitted_eqs:
        expr = expr.replace(lhs, rhs)
    return expr


for expr, eqs in equivalences.items():
    consts = [v for v in eqs if isinstance(v, sympy.Integer)]
    if len(consts) > 1:
        # Two distinct integers are forced equal.
        _raise_unsat()
    if consts:
        target = consts[0]
    else:
        # A purely symbolic class such as {s0, s1} is satisfiable: tie each
        # member to one deterministic representative instead of failing.
        target = min(eqs, key=sympy.default_sort_key)
    if expr == target:
        continue
    g = _substitute_emitted(mk_eq(expr, target))
    if g is sympy.true:
        continue  # already implied by earlier emitted equalities
    if g is sympy.false:
        _raise_unsat()
    simpl = simplify_eq(g)
    if isinstance(simpl, sympy.Rel):
        # Check the cheaper rewritten relation in the comparison pass below.
        comparisons[simpl.lhs].append(simpl)
    else:
        # No rewrite applies, or it did not yield a plain relation (e.g. it
        # simplified to a truth value): keep the equality itself.
        _emit_guard(g)
    emitted_eqs.append((g.lhs, g.rhs))

for cs in comparisons.values():
    # TODO: filter out comparisons implied by others over the same lhs (e.g.
    # disequalities implied by a strict bound). Keeping them all is redundant
    # at worst, never unsound.
    for c in cs:
        c = _substitute_emitted(c)
        if c is sympy.true:
            continue  # implied by the equalities emitted above
        if c is sympy.false:
            _raise_unsat()
        _emit_guard(c)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment