Skip to content

Instantly share code, notes, and snippets.

@PWhiddy
Created October 15, 2024 00:31
Show Gist options
  • Save PWhiddy/5117ee0d96b3f37bead164723c76133c to your computer and use it in GitHub Desktop.
Save PWhiddy/5117ee0d96b3f37bead164723c76133c to your computer and use it in GitHub Desktop.
Symbolic Polynomial Derivative Calculator
import numbers
from enum import Enum
from random import random
class Ops(Enum):
    """Binary operators understood by the expression AST (values 1..4)."""

    Add = 1
    Sub = 2
    Mult = 3
    Pow = 4
# 3 expression types for a simple AST
class Constant:
    """Leaf node holding a fixed value, stored unchanged in ``self.value``."""

    def __init__(self, value):
        self.value = value
class Var:
    """Stateless leaf node standing for the single independent variable x."""
def lift_num_to_constant(value):
    """Wrap a bare number in a ``Constant`` node; return anything else unchanged.

    This lets ``BinaryOp`` accept plain numbers alongside AST nodes, e.g.
    ``BinaryOp(3, Ops.Mult, Var())``.

    Any ``numbers.Number`` is lifted. This covers ``int`` and ``float`` (and
    therefore ``bool``) as before. It also covers ``fractions.Fraction``,
    ``decimal.Decimal`` and ``complex``, which previously slipped through
    unwrapped even though the AST functions only understand
    Constant/Var/BinaryOp nodes.

    Args:
        value: a number or an AST node.

    Returns:
        ``Constant(value)`` for numbers, otherwise ``value`` itself.
    """
    if isinstance(value, numbers.Number):
        return Constant(value)
    return value
class BinaryOp:
    """Interior AST node representing ``left <op> right``.

    Numeric operands are passed through ``lift_num_to_constant`` so callers
    can mix plain numbers with AST nodes. ``op`` is expected to be an ``Ops``
    member.
    """

    def __init__(self, left, op, right):
        # Lift left first, then right, then record the operator.
        self.left, self.right = (lift_num_to_constant(side) for side in (left, right))
        self.op = op
# display/print an AST
def display_expr(expr):
    """Render an AST as a fully parenthesised infix string.

    ``Var`` renders as ``x`` and a ``Constant`` as ``str(value)``. Every
    ``BinaryOp`` renders as ``(left op right)``, so precedence is never
    ambiguous.

    Raises:
        TypeError: if ``expr`` or any sub-node is not a Var, Constant or
            BinaryOp. Previously such nodes silently produced ``None``, which
            showed up as text like "(None + x)".
        KeyError: if a BinaryOp carries an operator missing from the symbol
            table below.
    """
    if isinstance(expr, Var):
        return "x"
    if isinstance(expr, Constant):
        return str(expr.value)
    if isinstance(expr, BinaryOp):
        disp_left = display_expr(expr.left)
        disp_right = display_expr(expr.right)
        ops_to_displayed = {Ops.Add: "+", Ops.Sub: "-", Ops.Mult: "*", Ops.Pow: "^"}
        disp_op = ops_to_displayed[expr.op]
        return f"({disp_left} {disp_op} {disp_right})"
    raise TypeError(f"cannot display non-AST node {expr!r}")
# evaluate an AST
def eval_expr(expr, var_value=1):
    """Numerically evaluate ``expr`` with every ``Var`` bound to ``var_value``.

    Args:
        expr: a Var, Constant or BinaryOp tree.
        var_value: value substituted for x (default 1).

    Returns:
        The result of applying Python's ``+``, ``-``, ``*`` and ``**``
        bottom-up. Constant values are returned as stored.

    Raises:
        TypeError: if a node is not a Var, Constant or BinaryOp (previously
            such nodes silently evaluated to None).
        ValueError: if a BinaryOp carries an operator not handled here
            (previously this also silently returned None).
    """
    if isinstance(expr, Var):
        return var_value
    if isinstance(expr, Constant):
        return expr.value
    if not isinstance(expr, BinaryOp):
        raise TypeError(f"cannot evaluate non-AST node {expr!r}")
    evaled_left = eval_expr(expr.left, var_value)
    evaled_right = eval_expr(expr.right, var_value)
    if expr.op == Ops.Add:
        return evaled_left + evaled_right
    if expr.op == Ops.Sub:
        return evaled_left - evaled_right
    if expr.op == Ops.Mult:
        return evaled_left * evaled_right
    if expr.op == Ops.Pow:
        return evaled_left ** evaled_right
    raise ValueError(f"unsupported operator {expr.op!r}")
def _depends_on_x(expr):
    """Return True if ``expr`` contains a Var node anywhere in its tree.

    Leaves that are not Var (Constants, or anything unrecognised) count as
    independent of x.
    """
    if isinstance(expr, Var):
        return True
    if isinstance(expr, BinaryOp):
        return _depends_on_x(expr.left) or _depends_on_x(expr.right)
    return False


# differentiate an AST with respect to Var
def dx_expr(expr):
    """Symbolically differentiate ``expr`` with respect to x (the Var node).

    Returns a new, unsimplified AST built with these rules:
    - Var gives 1 and Constant gives 0.
    - Add/Sub use the sum/difference rule.
    - Mult uses the product rule.
    - Pow uses the power rule.

    Raises:
        NotImplementedError: if a Pow exponent depends on x. The AST has no
            logarithm operator, so d/dx f^g with non-constant g cannot be
            expressed. This is a subclass of Exception, so existing
            ``except Exception`` handlers still catch it.
        ValueError: if a BinaryOp carries an operator not handled here.
        TypeError: if a node is not a Var, Constant or BinaryOp.
    """
    if isinstance(expr, Var):
        return Constant(1)
    if isinstance(expr, Constant):
        return Constant(0)
    if not isinstance(expr, BinaryOp):
        raise TypeError(f"cannot differentiate non-AST node {expr!r}")
    if expr.op == Ops.Add or expr.op == Ops.Sub:
        # (f +/- g)' = f' +/- g'
        return BinaryOp(dx_expr(expr.left), expr.op, dx_expr(expr.right))
    if expr.op == Ops.Mult:
        # product rule: (f * g)' = f' * g + f * g'
        return BinaryOp(BinaryOp(dx_expr(expr.left), Ops.Mult, expr.right),
                        Ops.Add, BinaryOp(expr.left, Ops.Mult, dx_expr(expr.right)))
    if expr.op == Ops.Pow:
        # Power rule: (f ^ c)' = c * f ^ (c - 1) * f'. It holds for any exponent
        # that does not depend on x, not only a literal Constant (e.g. x ^ (2 + 1)).
        if _depends_on_x(expr.right):
            raise NotImplementedError("Non-constant exponent not supported!")
        reduced_power = BinaryOp(expr.left, Ops.Pow, BinaryOp(expr.right, Ops.Sub, 1))
        return BinaryOp(BinaryOp(expr.right, Ops.Mult, reduced_power), Ops.Mult, dx_expr(expr.left))
    raise ValueError(f"unsupported operator {expr.op!r}")
# compute the derivative using finite differences with eval_expr
def finite_diff(expr, var_value, h=0.00001):
    """Approximate d(expr)/dx at ``var_value`` with a central difference.

    Args:
        expr: AST to differentiate numerically (evaluated via ``eval_expr``).
        var_value: point x at which to estimate the derivative.
        h: half-width of the difference stencil. Defaults to 1e-5, the value
            that was previously hard-coded.

    Returns:
        ``(f(x + h) - f(x - h)) / ((x + h) - (x - h))``.

    Raises:
        ValueError: if ``h`` is not positive.
    """
    if h <= 0:
        raise ValueError(f"step h must be positive, got {h!r}")
    upper = var_value + h
    lower = var_value - h
    # Divide by the step actually taken rather than 2 * h. Both x + h and
    # x - h are rounded to the nearest float, so for large |x| their distance
    # differs slightly from 2 * h, and that error gets multiplied by |f'|.
    return (eval_expr(expr, var_value=upper) - eval_expr(expr, var_value=lower)) / (upper - lower)
# test all the functions on an expression
def test_expr(expr, test_val=1):
    """Print a smoke-test report for ``expr``.

    The report contains, in order:
    - the rendered expression;
    - its value at ``x = test_val``;
    - the rendered symbolic derivative;
    - the numeric cross-check printed by ``test_deriv``;
    - blank separator lines.

    Args:
        expr: AST to exercise.
        test_val: point at which the expression is evaluated. Defaults to 1,
            which was previously hard-coded in both the code and the label.
    """
    print("display:")
    print(display_expr(expr))
    # Derive the label from test_val so the two can no longer drift apart.
    print(f"eval for x={test_val}:")
    print(eval_expr(expr, var_value=test_val))
    print("d/dx")
    print(display_expr(dx_expr(expr)))
    test_deriv(expr)
    print("\n")
# verify that both methods of calculating derivatives match within a tolerance
def test_deriv(expr):
    """Compare the symbolic derivative of ``expr`` with a finite-difference estimate.

    Picks a random x in [-50, 50) and evaluates ``dx_expr(expr)`` there. It
    compares that value with ``finite_diff`` and prints both values, their
    absolute difference, and whether they match.

    The match tolerance is the larger of two bounds:
    - an absolute 1e-3 (the original bound);
    - 1e-7 relative to the larger of the two magnitudes.

    A purely absolute bound demands ever tighter relative agreement as |f'|
    grows, while the floating-point error of a central difference scales with
    the size of the function values. For example, ((x + 1)^2 + 1)^2 has
    |f'| of about 5e5 at x = 50.
    """
    rand_eval_value = (random() - 0.5) * 100  # random between -50 and 50
    symbolic_diff_val = eval_expr(dx_expr(expr), rand_eval_value)
    finite_diff_val = finite_diff(expr, rand_eval_value)
    diff = abs(symbolic_diff_val - finite_diff_val)
    scale = max(abs(symbolic_diff_val), abs(finite_diff_val))
    is_match = diff < max(0.001, 1e-7 * scale)
    # Print diff in scientific notation. With .2f, a failing diff such as 0.003
    # would print as "0.00" next to "match: False".
    print(f"eval derivative at x={rand_eval_value:.1f} - symbolic: {symbolic_diff_val:.2f} " +
          f"finite: {finite_diff_val:.2f}, diff: {diff:.2e}, match: {is_match}")
# test various expressions: build every sample first, then run the report on
# each one in the same order (a to f).

# 3 * (11 - 15)
expr_a = BinaryOp(3, Ops.Mult, BinaryOp(11, Ops.Sub, 15))
# (3 ^ 10) * 0.1
expr_b = BinaryOp(BinaryOp(3, Ops.Pow, 10), Ops.Mult, 0.1)
# 4x^3
expr_c = BinaryOp(4, Ops.Mult, BinaryOp(Var(), Ops.Pow, 3))
# (x^2)^2
expr_d = BinaryOp(BinaryOp(Var(), Ops.Pow, 2), Ops.Pow, 2)
# 3x^2 + 6x - 7
expr_e = BinaryOp(
    BinaryOp(
        BinaryOp(3, Ops.Mult, BinaryOp(Var(), Ops.Pow, 2)),
        Ops.Add,
        BinaryOp(6, Ops.Mult, Var()),
    ),
    Ops.Sub,
    7,
)
# ((x + 1)^2 + 1)^2
expr_f = BinaryOp(
    BinaryOp(
        BinaryOp(BinaryOp(Var(), Ops.Add, 1), Ops.Pow, 2),
        Ops.Add,
        1,
    ),
    Ops.Pow,
    2,
)

for sample_expr in (expr_a, expr_b, expr_c, expr_d, expr_e, expr_f):
    test_expr(sample_expr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment