Skip to content

Instantly share code, notes, and snippets.

@ntezak
Last active December 1, 2015 18:44
Show Gist options
  • Save ntezak/e1922acdd790e265963e to your computer and use it in GitHub Desktop.
Save ntezak/e1922acdd790e265963e to your computer and use it in GitHub Desktop.
Snippet to improve execution speed of lambdify by identifying common subexpressions
import sympy
from collections import defaultdict
def walk_exprs(exprs, subexpr=None):
    """
    Return a dict {expr: frequency} for all subexpressions in the sequence of expressions `exprs`.

    Every non-atomic sympy object in `exprs`, and in the argument trees below
    it, is counted once per occurrence. A subexpression nested inside a
    repeated parent is therefore counted each time the parent occurs. Atoms
    (symbols, numbers, ...) are not counted. Nested lists/tuples inside
    `exprs` are descended into without being counted themselves. Any other
    non-sympy entries (plain ints, floats, ...) are ignored.

    If `subexpr` is given, counts are added to it in place and it is returned.
    It may be a plain dict or a defaultdict(int). Otherwise a new
    defaultdict(int) is returned.

    The traversal uses an explicit stack instead of recursion, so very deep
    expression trees do not hit Python's recursion limit.
    """
    if subexpr is None:
        subexpr = defaultdict(int)
    # Pre-order, left-to-right traversal: the same visiting order as a
    # recursive walk. Dict insertion order, and thus tie-breaking for callers
    # that sort by frequency, matches the recursive version.
    stack = list(exprs)
    stack.reverse()
    while stack:
        e = stack.pop()
        if isinstance(e, (list, tuple)):
            # Plain containers (as accepted by lambdify) are transparent.
            stack.extend(reversed(e))
        elif isinstance(e, sympy.Basic) and not isinstance(e, sympy.Atom):
            # .get() works for plain dicts as well as defaultdicts.
            subexpr[e] = subexpr.get(e, 0) + 1
            stack.extend(reversed(e.args))
    return subexpr
def _subs_nested(obj, subs):
    """Apply the substitution dict `subs` throughout a possibly nested list/tuple.

    Lists and tuples keep their container type. Objects with a `subs` method
    (sympy expressions, sympy matrices) are substituted. Anything else
    (plain numbers, ...) is returned unchanged.
    """
    if isinstance(obj, list):
        return [_subs_nested(item, subs) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_subs_nested(item, subs) for item in obj)
    if hasattr(obj, "subs"):
        return obj.subs(subs)
    return obj


def my_lambdify(args, exprs, threshold=5, maxwrap=10, prefix=0, **kwargs):
    """
    Create nested lambdas for reusing common subexpressions.

    Arg spec is like sympy.lambdify, except for:
    threshold: minimum frequency of subexpression to be reused.
    maxwrap: maximal number of subexpressions to reuse.
    prefix: used internally, can be any number (only affects the names
        of the generated dummy symbols).

    `exprs` must be a list/tuple of expressions. It may contain further
    nested lists/tuples and plain numbers.

    How it works:
    1. Select at most `maxwrap` of the most frequent non-atomic
       subexpressions, keeping those that occur at least `threshold` times.
    2. Compile them into one function of `args`.
    3. Replace every occurrence of each selected subexpression with a fresh
       dummy symbol.
    4. Compile the rewritten expressions recursively, with those dummies
       appended to the argument list.

    The returned callable evaluates the shared subexpressions first and
    passes their values to the inner function as the extra arguments. If no
    subexpression qualifies, the result is simply
    ``sympy.lambdify(args, exprs, **kwargs)``. When subexpressions are
    reused, the returned callable accepts positional arguments only.
    """
    freqs = sorted(walk_exprs(exprs).items(), key=lambda item: item[1], reverse=True)
    candidates = [expr for expr, count in freqs[:maxwrap] if count >= threshold]
    if not candidates:
        return sympy.lambdify(args, exprs, **kwargs)

    # Dummy symbols are unique, so they cannot collide with a user symbol
    # that happens to carry the same name.
    subsymbs = tuple(
        sympy.Dummy("my_dummy_{}_{}".format(k, prefix))
        for k in range(len(candidates))
    )
    subs = dict(zip(candidates, subsymbs))
    nexprs = _subs_nested(exprs, subs)
    if nexprs == exprs:
        # Nothing was substituted. Recursing would select the same
        # candidates again and never terminate.
        return sympy.lambdify(args, exprs, **kwargs)

    # sympy.lambdify also accepts a single expression (e.g. one Symbol) as
    # `args`. Normalize to a tuple so the dummies can be appended.
    base_args = (args,) if isinstance(args, sympy.Expr) else tuple(args)
    shared = sympy.lambdify(args, list(candidates), **kwargs)
    inner = my_lambdify(
        base_args + subsymbs,
        nexprs,
        threshold=threshold,
        maxwrap=maxwrap,
        prefix=prefix + 1,
        **kwargs
    )

    def fn(*vals):
        # tuple() makes the concatenation independent of the container
        # type `shared` returns.
        return inner(*(vals + tuple(shared(*vals))))

    return fn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment