# Symbolic Polynomial Derivative Calculator
#
# GitHub Gist PWhiddy/5117ee0d96b3f37bead164723c76133c, created October 15, 2024 00:31.
# The gist can be saved to your computer or opened with GitHub Desktop.
#
# GitHub notice: this file may contain bidirectional Unicode text that can be
# interpreted or compiled differently than it appears. To review it, open the
# file in an editor that reveals hidden Unicode characters (see GitHub's
# documentation on bidirectional Unicode characters).
import math
from enum import Enum
from numbers import Real
from random import random
class Ops(Enum):
    """Binary operators a BinaryOp node can carry (values 1-4 in declaration order)."""

    Add = 1
    Sub = 2
    Mult = 3
    Pow = 4
# The simple AST has three node types: Constant, Var, and BinaryOp.
class Constant:
    """AST leaf holding a literal value (typically a number).

    Evaluates to ``value``; its derivative is always ``Constant(0)``.
    """

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        # Makes AST trees readable in a debugger or REPL instead of <object at 0x...>.
        return f"{type(self).__name__}({self.value!r})"
class Var:
    """AST leaf standing for the single variable x.

    Carries no state: every Var instance denotes the same variable, displayed
    as ``x`` and evaluated to whatever value is supplied for x.
    """

    def __repr__(self):
        # Makes AST trees readable in a debugger or REPL instead of <object at 0x...>.
        return f"{type(self).__name__}()"
def lift_num_to_constant(value):
    """Wrap a plain real number in a Constant; return anything else unchanged.

    Any ``numbers.Real`` is lifted: int, float, bool, fractions.Fraction, and
    third-party scalars that register with ``numbers.Real``. Previously only
    int/float were lifted, so e.g. a Fraction operand stayed a raw number and
    display/eval/dx silently returned None for it.

    Args:
        value: a number or an existing AST node (Var, Constant, BinaryOp).

    Returns:
        ``Constant(value)`` for real numbers, otherwise ``value`` itself.
    """
    if isinstance(value, Real):
        return Constant(value)
    return value
class BinaryOp:
    """AST node applying a binary operator (an Ops member) to two sub-expressions.

    Plain numbers given as operands are wrapped in Constant by
    lift_num_to_constant, so ``BinaryOp(3, Ops.Mult, Var())`` is allowed.

    Attributes:
        left: left operand AST.
        op: the Ops member to apply.
        right: right operand AST.
    """

    def __init__(self, left, op, right):
        self.left = lift_num_to_constant(left)
        self.right = lift_num_to_constant(right)
        self.op = op

    def __repr__(self):
        # Recursive repr of the whole tree, e.g. BinaryOp(Var(), Ops.Pow, Constant(2)).
        return f"{type(self).__name__}({self.left!r}, {self.op}, {self.right!r})"
# display/print an AST
def display_expr(expr):
    """Render an AST as a fully parenthesised infix string.

    Var prints as ``x``, Constant as ``str(value)``, and each BinaryOp as
    ``(left op right)`` using the symbols + - * ^.

    Args:
        expr: a Var, Constant, or BinaryOp node.

    Returns:
        The string form of ``expr``.

    Raises:
        TypeError: if ``expr`` (or a sub-expression) is not an AST node.
            Previously such input returned None, which was then silently
            formatted as "None" inside the parent's string.
        KeyError: if a BinaryOp carries an operator outside Ops.
    """
    if isinstance(expr, Var):
        return "x"
    if isinstance(expr, Constant):
        return str(expr.value)
    if isinstance(expr, BinaryOp):
        disp_left = display_expr(expr.left)
        disp_right = display_expr(expr.right)
        ops_to_displayed = {Ops.Add: "+", Ops.Sub: "-", Ops.Mult: "*", Ops.Pow: "^"}
        disp_op = ops_to_displayed[expr.op]
        return f"({disp_left} {disp_op} {disp_right})"
    raise TypeError(f"Cannot display object of type {type(expr).__name__}")
# evaluate an AST
def eval_expr(expr, var_value=1):
    """Evaluate an AST numerically with the variable x bound to ``var_value``.

    Args:
        expr: a Var, Constant, or BinaryOp node.
        var_value: value substituted for every Var node (default 1).

    Returns:
        The result of applying Python's +, -, * and ** to the evaluated operands.

    Raises:
        TypeError: if ``expr`` (or a sub-expression) is not an AST node.
        ValueError: if a BinaryOp carries an operator outside Ops.
        Both cases previously returned None silently.
    """
    if isinstance(expr, Var):
        return var_value
    if isinstance(expr, Constant):
        return expr.value
    if not isinstance(expr, BinaryOp):
        raise TypeError(f"Cannot evaluate object of type {type(expr).__name__}")

    evaled_left = eval_expr(expr.left, var_value)
    evaled_right = eval_expr(expr.right, var_value)
    if expr.op == Ops.Add:
        return evaled_left + evaled_right
    if expr.op == Ops.Sub:
        return evaled_left - evaled_right
    if expr.op == Ops.Mult:
        return evaled_left * evaled_right
    if expr.op == Ops.Pow:
        return evaled_left ** evaled_right
    raise ValueError(f"Unknown operator: {expr.op!r}")
# differentiate an AST with respect to Var
def _depends_on_var(expr):
    """Return True if ``expr`` contains a Var node anywhere in its subtree."""
    if isinstance(expr, Var):
        return True
    if isinstance(expr, BinaryOp):
        return _depends_on_var(expr.left) or _depends_on_var(expr.right)
    return False


def dx_expr(expr):
    """Symbolically differentiate an AST with respect to the variable x.

    Applies the sum/difference rule, the product rule, and the power rule with
    the chain rule. The result is a new, unsimplified AST.

    The power rule ``d/dx f^n = n * f^(n-1) * f'`` holds whenever the exponent
    does not depend on x. So any x-free exponent subtree is accepted (e.g.
    ``x ^ (1 + 1)``), not only a bare Constant.

    Args:
        expr: a Var, Constant, or BinaryOp node.

    Returns:
        An AST for d(expr)/dx.

    Raises:
        NotImplementedError: for a power whose exponent depends on x. That case
            would need a logarithm, which the AST cannot express.
        TypeError: if ``expr`` (or a sub-expression) is not an AST node.
        ValueError: if a BinaryOp carries an operator outside Ops.
    """
    if isinstance(expr, Var):
        return Constant(1)
    if isinstance(expr, Constant):
        return Constant(0)
    if not isinstance(expr, BinaryOp):
        raise TypeError(f"Cannot differentiate object of type {type(expr).__name__}")

    left, op, right = expr.left, expr.op, expr.right

    if op == Ops.Pow:
        if _depends_on_var(right):
            raise NotImplementedError("Non-constant exponent not supported!")
        # power rule + chain rule: right * left^(right - 1) * d(left)/dx
        return BinaryOp(
            BinaryOp(right, Ops.Mult, BinaryOp(left, Ops.Pow, BinaryOp(right, Ops.Sub, 1))),
            Ops.Mult,
            dx_expr(left),
        )

    diff_left = dx_expr(left)
    diff_right = dx_expr(right)
    if op == Ops.Add or op == Ops.Sub:
        return BinaryOp(diff_left, op, diff_right)
    # product rule
    if op == Ops.Mult:
        return BinaryOp(
            BinaryOp(diff_left, Ops.Mult, right),
            Ops.Add,
            BinaryOp(left, Ops.Mult, diff_right),
        )
    raise ValueError(f"Unknown operator: {op!r}")
# compute the derivative numerically using central finite differences with eval_expr
def finite_diff(expr, var_value, h=0.00001):
    """Approximate d(expr)/dx at ``var_value`` with a central difference.

    Computes ``(f(x + h) - f(x - h)) / (2h)``, where f is ``eval_expr(expr, ...)``.

    Args:
        expr: AST to differentiate numerically.
        var_value: point x at which to estimate the derivative.
        h: step size. The default 1e-5 is the previously hard-coded step.

    Returns:
        The finite-difference estimate of the derivative.
    """
    forward = eval_expr(expr, var_value=var_value + h)
    backward = eval_expr(expr, var_value=var_value - h)
    return (forward - backward) / (2 * h)
# test all the functions on an expression
def test_expr(expr, test_val=1):
    """Print the display form, value, symbolic derivative and derivative check of ``expr``.

    Args:
        expr: AST to exercise.
        test_val: value of x used for the printed evaluation (default 1).
            The printed label is derived from it, so the two can no longer
            disagree.
    """
    print("display:")
    print(display_expr(expr))
    print(f"eval for x={test_val}:")
    print(eval_expr(expr, var_value=test_val))
    print("d/dx")
    print(display_expr(dx_expr(expr)))
    test_deriv(expr)
    print("\n")
# verify that both methods of calculating derivatives match within a tolerance
def test_deriv(expr, eval_value=None, *, rel_tol=1e-6, abs_tol=0.001):
    """Compare the symbolic derivative with a finite-difference estimate and print the result.

    Args:
        expr: AST to check.
        eval_value: point x to compare at. If None, a random point uniformly
            drawn from [-50, 50) is used.
        rel_tol: relative tolerance for the comparison.
        abs_tol: absolute tolerance, used when the derivative is near zero.
            Defaults to the previous fixed threshold of 0.001.

    Returns:
        True if the two derivative values agree within the tolerances.
    """
    if eval_value is None:
        eval_value = (random() - 0.5) * 100  # random between -50 and 50
    symbolic_diff_val = eval_expr(dx_expr(expr), eval_value)
    finite_diff_val = finite_diff(expr, eval_value)
    diff = abs(symbolic_diff_val - finite_diff_val)
    # An absolute-only threshold is too strict for large derivatives. For example,
    # ((x+1)^2+1)^2 near |x|=50 has f' ~ 5e5, and floating-point rounding in the
    # finite difference alone is ~1e-4..1e-3. So also allow a relative error.
    is_match = math.isclose(symbolic_diff_val, finite_diff_val, rel_tol=rel_tol, abs_tol=abs_tol)
    print(f"eval derivative at x={eval_value:.1f} - symbolic: {symbolic_diff_val:.2f} "
          f"finite: {finite_diff_val:.2f}, diff: {diff:.2f}, match: {is_match}")
    return is_match
# Example expressions, built first and then exercised one by one in order.
# 3 * (11 - 15)
expr_a = BinaryOp(3, Ops.Mult, BinaryOp(11, Ops.Sub, 15))
# (3 ^ 10) * 0.1
expr_b = BinaryOp(BinaryOp(3, Ops.Pow, 10), Ops.Mult, 0.1)
# 4x^3
expr_c = BinaryOp(4, Ops.Mult, BinaryOp(Var(), Ops.Pow, 3))
# (x^2)^2
expr_d = BinaryOp(BinaryOp(Var(), Ops.Pow, 2), Ops.Pow, 2)
# 3x^2 + 6x - 7
expr_e = BinaryOp(
    BinaryOp(
        BinaryOp(3, Ops.Mult, BinaryOp(Var(), Ops.Pow, 2)),
        Ops.Add,
        BinaryOp(6, Ops.Mult, Var()),
    ),
    Ops.Sub,
    7,
)
# ((x + 1)^2 + 1)^2
expr_f = BinaryOp(
    BinaryOp(BinaryOp(BinaryOp(Var(), Ops.Add, 1), Ops.Pow, 2), Ops.Add, 1),
    Ops.Pow,
    2,
)

# Run the full display/eval/derivative check on each example, in definition order.
for _example in (expr_a, expr_b, expr_c, expr_d, expr_e, expr_f):
    test_expr(_example)
del _example  # keep the module namespace the same as before
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.