Created October 17, 2024 18:51
-
-
Save PWhiddy/a7e87ef944c39564db34f5e9e0f1486f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math
import numbers
from enum import Enum
from random import random
class Ops(Enum):
    """Binary operators that a BinaryOp node can apply."""

    Add = 1
    Sub = 2
    Mult = 3
    Pow = 4
# 3 expression types for a simple AST: Constant, Var and BinaryOp
class Constant:
    """Leaf node holding a literal numeric value."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        # Debug-friendly representation, e.g. Constant(3).
        return f"Constant({self.value!r})"
class Var:
    """Leaf node standing for the single free variable x; it carries no data."""

    def __repr__(self):
        # Debug-friendly representation; all Var nodes denote the same x.
        return "Var()"
def lift_num_to_constant(value):
    """Wrap a plain number in a Constant node; return anything else unchanged.

    Lets callers write ``BinaryOp(Var(), Ops.Pow, 2)`` instead of
    ``BinaryOp(Var(), Ops.Pow, Constant(2))``. Any ``numbers.Number`` is
    lifted (int, bool, float, Fraction, Decimal, complex, ...), not just
    int/float. Previously other numeric types were left unwrapped, and the
    AST walkers silently returned None for them. Non-numbers, such as
    existing AST nodes, pass through unchanged.
    """
    if isinstance(value, numbers.Number):
        return Constant(value)
    return value
class BinaryOp:
    """Interior node applying ``op`` to a left and a right subexpression.

    ``op`` is expected to be an ``Ops`` member; the evaluation, display and
    differentiation functions dispatch on it.
    """

    def __init__(self, left, op, right):
        # Plain numbers are accepted for convenience and wrapped as Constant nodes.
        self.left = lift_num_to_constant(left)
        self.right = lift_num_to_constant(right)
        self.op = op

    def __repr__(self):
        # Debug-friendly representation, e.g. BinaryOp(Var(), Ops.Pow, Constant(2)).
        return f"BinaryOp({self.left!r}, {self.op}, {self.right!r})"
# display/print an AST
def display_expr(expr):
    """Render an AST as a fully parenthesised infix string, e.g. ``(x ^ 2)``.

    Var renders as ``x``, a Constant as ``str(value)``, and every BinaryOp
    as ``(<left> <symbol> <right>)``.

    Raises:
        TypeError: if ``expr`` is not a Var, Constant or BinaryOp. Previously
            such input was silently rendered as ``None``.
        KeyError: if a BinaryOp's operator has no display symbol.
    """
    if isinstance(expr, Var):
        return "x"
    if isinstance(expr, Constant):
        return str(expr.value)
    if isinstance(expr, BinaryOp):
        disp_left = display_expr(expr.left)
        disp_right = display_expr(expr.right)
        ops_to_displayed = {Ops.Add: "+", Ops.Sub: "-", Ops.Mult: "*", Ops.Pow: "^"}
        disp_op = ops_to_displayed[expr.op]
        return f"({disp_left} {disp_op} {disp_right})"
    raise TypeError(f"cannot display object of type {type(expr).__name__!r} as an expression")
# evaluate an AST
def eval_expr(expr, var_value=1):
    """Numerically evaluate an AST with every Var bound to ``var_value``.

    Args:
        expr: a Var, Constant or BinaryOp tree.
        var_value: value substituted for x (default 1).

    Returns:
        The result of applying Python's ``+``, ``-``, ``*`` and ``**`` to the
        leaf values. Arithmetic errors such as ZeroDivisionError propagate.

    Raises:
        TypeError: if a node is not a Var, Constant or BinaryOp.
        ValueError: if a BinaryOp carries an operator this function does not handle.
        Both cases previously returned None silently, which only failed later
        with a confusing error in the caller.
    """
    if isinstance(expr, Var):
        return var_value
    if isinstance(expr, Constant):
        return expr.value
    if not isinstance(expr, BinaryOp):
        raise TypeError(f"cannot evaluate object of type {type(expr).__name__!r}")
    evaled_left = eval_expr(expr.left, var_value)
    evaled_right = eval_expr(expr.right, var_value)
    if expr.op == Ops.Add:
        return evaled_left + evaled_right
    if expr.op == Ops.Sub:
        return evaled_left - evaled_right
    if expr.op == Ops.Mult:
        return evaled_left * evaled_right
    if expr.op == Ops.Pow:
        return evaled_left ** evaled_right
    raise ValueError(f"unsupported operator: {expr.op!r}")
def is_one(expr):
    """Return True only when ``expr`` is a Constant node whose value equals 1."""
    if not isinstance(expr, Constant):
        return False
    return bool(expr.value == 1)
def is_zero(expr):
    """Return True only when ``expr`` is a Constant node whose value equals 0."""
    return isinstance(expr, Constant) and bool(expr.value == 0)
def simplify_expr(expr):
    """Return a simplified version of an AST, working bottom-up.

    The input tree is never modified. New BinaryOp nodes are built, while leaf
    nodes may be shared with the input. Rules applied after simplifying both
    operands:

      * a BinaryOp whose operands are both Constants is folded to a Constant;
      * ``e + 0``, ``0 + e``, ``e - 0``  ->  ``e``
      * ``e * 1``, ``1 * e``  ->  ``e``;  ``e * 0``, ``0 * e``  ->  ``0``
      * ``e ^ 1``  ->  ``e``;  ``e ^ 0``  ->  ``1``

    ``0 - e`` is deliberately left unsimplified: it equals ``-e``, not ``e``.
    An earlier version collapsed it to ``e``, which flipped the sign of
    derivatives such as d/dx (5 - x^2).

    Raises:
        TypeError: if a node is not a Var, Constant or BinaryOp.
    """
    if isinstance(expr, (Var, Constant)):
        return expr
    if not isinstance(expr, BinaryOp):
        raise TypeError(f"cannot simplify object of type {type(expr).__name__!r}")
    op = expr.op
    simpl_left = simplify_expr(expr.left)
    simpl_right = simplify_expr(expr.right)
    if isinstance(simpl_left, Constant) and isinstance(simpl_right, Constant):
        return Constant(eval_expr(BinaryOp(simpl_left, op, simpl_right)))
    # A zero on the left is an identity only for addition.
    if op == Ops.Add and is_zero(simpl_left):
        return simpl_right
    if op in (Ops.Add, Ops.Sub) and is_zero(simpl_right):
        return simpl_left
    if op == Ops.Mult:
        if is_one(simpl_left):
            return simpl_right
        if is_one(simpl_right):
            return simpl_left
        if is_zero(simpl_left) or is_zero(simpl_right):
            return Constant(0)
    if op == Ops.Pow:
        if is_one(simpl_right):
            return simpl_left
        if is_zero(simpl_right):
            return Constant(1)
    return BinaryOp(simpl_left, op, simpl_right)
# return True if the subtree contains the variable x anywhere
def _contains_var(expr):
    """Return True if a Var node appears anywhere in ``expr``."""
    if isinstance(expr, Var):
        return True
    if isinstance(expr, BinaryOp):
        return _contains_var(expr.left) or _contains_var(expr.right)
    return False


# differentiate an AST with respect to Var
def dx_expr(expr):
    """Symbolically differentiate an AST with respect to x.

    The result is not simplified; pass it through ``simplify_expr`` before
    displaying it. Supported:
      * sums and differences, term by term;
      * products, via the product rule;
      * powers whose exponent does not depend on x, via the power rule
        combined with the chain rule. The exponent may be any x-free
        subtree such as ``(1 + 1)``, not only a literal Constant.

    Raises:
        NotImplementedError: if an exponent depends on x (no log/exp support).
            This is a subclass of Exception, so existing ``except Exception``
            handlers still catch it.
        TypeError: if a node is not a Var, Constant or BinaryOp.
        ValueError: if a BinaryOp carries an operator this function does not handle.
    """
    if isinstance(expr, Var):
        return Constant(1)
    if isinstance(expr, Constant):
        return Constant(0)
    if not isinstance(expr, BinaryOp):
        raise TypeError(f"cannot differentiate object of type {type(expr).__name__!r}")
    diff_left = dx_expr(expr.left)
    if expr.op == Ops.Pow:
        # d/dx f^c = c * f^(c - 1) * f', valid for any exponent c free of x.
        if _contains_var(expr.right):
            raise NotImplementedError("Non-constant exponent not supported!")
        reduced_power = BinaryOp(expr.left, Ops.Pow, BinaryOp(expr.right, Ops.Sub, 1))
        return BinaryOp(BinaryOp(expr.right, Ops.Mult, reduced_power), Ops.Mult, diff_left)
    diff_right = dx_expr(expr.right)
    if expr.op in (Ops.Add, Ops.Sub):
        return BinaryOp(diff_left, expr.op, diff_right)
    if expr.op == Ops.Mult:
        # product rule: (f * g)' = f' * g + f * g'
        return BinaryOp(BinaryOp(diff_left, Ops.Mult, expr.right),
                        Ops.Add, BinaryOp(expr.left, Ops.Mult, diff_right))
    raise ValueError(f"unsupported operator: {expr.op!r}")
# compute the derivative numerically using finite differences with eval_expr
def finite_diff(expr, var_value, h=0.00001):
    """Approximate d(expr)/dx at ``var_value`` with a central difference.

    Computes ``(f(x + h) - f(x - h)) / (2h)``, whose truncation error is
    O(h^2). ``h`` defaults to the step that used to be hard-coded (1e-5).
    A smaller step amplifies floating-point cancellation in the numerator;
    a larger step increases truncation error.
    """
    forward = eval_expr(expr, var_value=var_value + h)
    backward = eval_expr(expr, var_value=var_value - h)
    return (forward - backward) / (2 * h)
# test all the functions on an expression
def test_expr(expr, test_val=1):
    """Exercise every AST function on ``expr`` and print the results.

    Prints the expression, its value at x = ``test_val``, and its simplified
    symbolic derivative. It then cross-checks the derivative numerically via
    ``test_deriv``. The printed label is derived from ``test_val``; it used
    to be hard-coded as "x=1" regardless of the value actually evaluated.
    """
    print("display:")
    print(display_expr(expr))
    print(f"eval for x={test_val}:")
    print(eval_expr(expr, var_value=test_val))
    print("d/dx")
    print(display_expr(simplify_expr(dx_expr(expr))))
    test_deriv(expr)
    # print("\n") emits two newlines, visually separating consecutive reports.
    print("\n")
# verify that both methods of calculating derivatives match within a tolerance
def test_deriv(expr, var_value=None, rel_tol=1e-6, abs_tol=0.001):
    """Compare the symbolic derivative of ``expr`` with a finite-difference estimate.

    Args:
        expr: AST to differentiate.
        var_value: point at which to compare. If None (the default), a random
            point in [-50, 50) is drawn, as before.
        rel_tol, abs_tol: tolerances passed to ``math.isclose``. The central
            difference's floating-point cancellation error grows with |f(x)|,
            so a purely absolute threshold gives false mismatches for large
            derivatives. ``abs_tol`` keeps the original 0.001 bound for
            derivatives near zero.

    Returns:
        True if the two derivative values agree within tolerance.
    """
    if var_value is None:
        var_value = (random() - 0.5) * 100  # random between -50 and 50
    symbolic_diff_val = eval_expr(dx_expr(expr), var_value)
    finite_diff_val = finite_diff(expr, var_value)
    diff = abs(symbolic_diff_val - finite_diff_val)
    is_match = math.isclose(symbolic_diff_val, finite_diff_val, rel_tol=rel_tol, abs_tol=abs_tol)
    print(f"eval derivative at x={var_value:.1f} - symbolic: {symbolic_diff_val:.2f} " +
          f"finite: {finite_diff_val:.2f}, diff: {diff:.2f}, match: {is_match}")
    return is_match
# Sample expressions exercised when the script runs.

# 3 * (11 - 15)
expr_a = BinaryOp(3, Ops.Mult, BinaryOp(11, Ops.Sub, 15))

# (3 ^ 10) * 0.1
expr_b = BinaryOp(BinaryOp(3, Ops.Pow, 10), Ops.Mult, 0.1)

# 4x^3
expr_c = BinaryOp(4, Ops.Mult, BinaryOp(Var(), Ops.Pow, 3))

# (x^2)^2
expr_d = BinaryOp(BinaryOp(Var(), Ops.Pow, 2), Ops.Pow, 2)

# 3x^2 + 6x - 7
expr_e = BinaryOp(
    BinaryOp(
        BinaryOp(3, Ops.Mult, BinaryOp(Var(), Ops.Pow, 2)),
        Ops.Add,
        BinaryOp(6, Ops.Mult, Var()),
    ),
    Ops.Sub,
    7,
)

# ((x + 1)^2 + 1)^2
expr_f = BinaryOp(
    BinaryOp(BinaryOp(BinaryOp(Var(), Ops.Add, 1), Ops.Pow, 2), Ops.Add, 1),
    Ops.Pow,
    2,
)

# Run the full report for each sample, in order.
for _expr in (expr_a, expr_b, expr_c, expr_d, expr_e, expr_f):
    test_expr(_expr)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.