Last active
April 29, 2024 00:10
-
-
Save worldbeater/da12e268babb0f1e088a7df1a98307a9 to your computer and use it in GitHub Desktop.
A simple term rewriting system (TRS) based on structural pattern matching; see https://peps.python.org/pep-0636 and https://inst.eecs.berkeley.edu/~cs294-260/sp24/2024-01-22-term-rewriting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def topdown(rule, expr): | |
match rule(expr): | |
case (spec, *args): | |
return (spec, *(topdown(rule, arg) for arg in args)) | |
case expr: | |
return expr | |
def rewrite(rule, expr):
    """Apply *rule* top-down repeatedly until the term stops changing.

    Each iteration performs one ``topdown`` pass. The loop ends at the first
    pass whose result compares equal to its input, and that result is
    returned.

    NOTE: there is no iteration limit, so a rule set that never reaches a
    fixed point (e.g. two rules that undo each other) loops forever.
    """
    current = expr
    while True:
        candidate = topdown(rule, current)
        if candidate == current:
            return candidate
        current = candidate
def derive(expr): | |
match expr: | |
case ('∂', ('add', u, v), var): | |
return ('add', ('∂', u, var), ('∂', v, var)) | |
case ('∂', ('pow', var0, int(n)), var) if var == var0: | |
return ('mul', n, ('pow', var, n - 1)) | |
case ('∂', ('div', 1, var0), var) if var == var0: | |
return ('∂', ('pow', var, -1), var) | |
case expr: | |
return expr | |
def simplify(expr): | |
match expr: | |
case ('add', a, ('unm', b)) | ('add', ('unm', b), a): | |
return ('sub', a, b) | |
case ('mul', -1, var) | ('mul', var, -1): | |
return ('unm', var) | |
case ('pow', var, 1): | |
return var | |
case expr: | |
return expr | |
# LaTeX templates keyed by operator tag; each is filled with ``%`` using the
# rendered arguments in order. Braces group operands for LaTeX but print no
# parentheses, so operator precedence is not made visible in the output.
OPS = {
    'add': '{%s}+{%s}',
    'sub': '{%s}-{%s}',
    'div': '\\frac{%s}{%s}',
    'mul': '{%s}{%s}',
    'pow': '{%s}^{%s}',
    '∂': '\\frac{\\partial{\\left(%s\\right)}}{\\partial{%s}}',
    # Unary minus: ``simplify`` produces 'unm' nodes, so a template is needed
    # for rendering simplified trees that still contain a negation.
    'unm': '-{%s}',
}
def tex(tree): | |
match tree: | |
case (spec, *args): | |
return OPS[spec] % tuple(map(tex, args)) | |
case atom: | |
return atom | |
# Demo: differentiate (x^2 + 1/x) with respect to x, simplify the result,
# then print each stage and a LaTeX rendering of the original and final terms.
expr = ('∂', ('add', ('pow', 'x', 2), ('div', 1, 'x')), 'x')
derv = rewrite(derive, expr)
simp = rewrite(simplify, derv)

# Original expression, its derivative, then the simplified derivative.
for stage in (expr, derv, simp):
    print(stage)

print(tex(expr), '=', tex(simp))  # LaTeX markup.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment