Skip to content

Instantly share code, notes, and snippets.

@kkroesch
Last active January 12, 2022 08:13
Show Gist options
  • Save kkroesch/6a5c068ee734a695d4528d24574e7c12 to your computer and use it in GitHub Desktop.
Save kkroesch/6a5c068ee734a695d4528d24574e7c12 to your computer and use it in GitHub Desktop.
Model algebraic expressions
from dataclasses import dataclass
from numbers import Number
from unittest import TestCase
class Expression:
    """Common ancestor of every node in an algebraic expression tree."""

    def derive(self, variable):
        """Fallback differentiation rule: hand back this very node unchanged.

        ``variable`` is accepted for interface compatibility but is not used
        here; subclasses may override this method with their own rule.

        NOTE(review): returning the node itself is not the mathematically
        correct derivative for most expressions — subclasses should override.
        """
        unchanged = self
        return unchanged
@dataclass
class BinaryExpression(Expression):
    """Expression node that combines exactly two operand sub-expressions.

    Concrete operators (``Times``, ``Plus``, ``Exp``) subclass this and supply
    their own ``__str__`` and ``evaluate``; the dataclass decorator generates
    ``__init__(left, right)``, ``__repr__`` and ``__eq__``.
    """

    # Operand on the left-hand side of the operator.
    left: Expression
    # Operand on the right-hand side of the operator.
    right: Expression
@dataclass
class Const(Expression):
    """A numeric constant leaf in the expression tree."""

    # The wrapped numeric value.
    value: Number

    def __str__(self) -> str:
        return str(self.value)

    def evaluate(self, env=None):
        """Return the constant value; ``env`` is accepted but ignored."""
        return self.value

    def derive(self, variable):
        """Return the derivative of a constant, which is always ``Const(0)``."""
        return Const(0)

    def __add__(self, other):
        """Add the raw values and return a plain number (not a ``Const``).

        ``other`` may be another ``Const`` (anything exposing ``.value``) or a
        plain number.
        """
        return self.value + getattr(other, "value", other)

    def __mul__(self, other):
        """Multiply the raw values and return a plain number (not a ``Const``).

        ``other`` may be another ``Const`` (anything exposing ``.value``) or a
        plain number.
        """
        return self.value * getattr(other, "value", other)
@dataclass
class Variable(Expression):
    """A named variable leaf whose value is looked up in an environment."""

    # Key used to look the value up in ``env``.
    name: str

    def __str__(self) -> str:
        return str(self.name)

    def evaluate(self, env=None):
        """Return ``env[self.name]``.

        An omitted ``env`` is treated as an empty mapping.

        Raises:
            KeyError: if the variable is not bound in ``env``.
        """
        bindings = {} if env is None else env
        return bindings[self.name]

    def derive(self, variable):
        """Differentiate with respect to ``variable`` (a ``Variable`` or a name).

        Returns ``Const(1)`` when ``variable`` is this variable, else ``Const(0)``.
        """
        target = getattr(variable, "name", variable)
        return Const(1 if target == self.name else 0)
class Times(BinaryExpression):
    """Product node: ``left * right``."""

    def __str__(self) -> str:
        # No parentheses are added here; operands print themselves.
        return f"{self.left} * {self.right}"

    def evaluate(self, env=None):
        """Evaluate both operands in ``env`` and multiply the results."""
        # Normalise once so children always receive a mapping.
        env = {} if env is None else env
        return self.left.evaluate(env) * self.right.evaluate(env)

    def derive(self, variable):
        """Apply the product rule: (u * v)' = u' * v + u * v'.

        The result is unsimplified and is only as correct as the operands'
        own ``derive`` implementations.
        """
        return Plus(
            Times(self.left.derive(variable), self.right),
            Times(self.left, self.right.derive(variable)),
        )
class Plus(BinaryExpression):
    """Addition node: ``left + right``."""

    def __str__(self) -> str:
        # Parenthesised so the sum reads correctly inside a product, e.g. "3 * (x + y)".
        return f"({self.left} + {self.right})"

    def evaluate(self, env=None):
        """Evaluate both operands in ``env`` and add the results."""
        # Normalise once so children always receive a mapping.
        env = {} if env is None else env
        return self.left.evaluate(env) + self.right.evaluate(env)

    def derive(self, variable):
        """Apply the sum rule: (u + v)' = u' + v'.

        The result is unsimplified and is only as correct as the operands'
        own ``derive`` implementations.
        """
        return Plus(self.left.derive(variable), self.right.derive(variable))
@dataclass
class Sum(Expression):
    """N-ary sum, e.g. ``Sum(a, b, c)`` for ``a + b + c``.

    Summands may be ``Expression`` nodes or plain numbers (``Sum(2, 4, 6)``).
    """

    # The positional constructor arguments, stored as a tuple.
    summands: tuple

    def __init__(self, *args):
        # Custom __init__ takes varargs; @dataclass keeps it rather than
        # generating its own.
        self.summands = args

    def __str__(self) -> str:
        return ' + '.join(map(str, self.summands))

    def evaluate(self, env=None):
        """Evaluate every summand in ``env`` and return their total.

        ``Expression`` summands are evaluated. Any other summand, such as a
        plain number, is used as-is. An empty ``Sum`` evaluates to 0.
        """
        env = {} if env is None else env
        return sum(
            term.evaluate(env) if isinstance(term, Expression) else term
            for term in self.summands
        )

    def derive(self, variable):
        """Differentiate term by term.

        Plain-number summands contribute ``Const(0)``.
        """
        return Sum(*(
            term.derive(variable) if isinstance(term, Expression) else Const(0)
            for term in self.summands
        ))
class Exp(BinaryExpression):
    """Power node: ``left ** right``, printed as ``left^right``."""

    def __str__(self) -> str:
        return f"{self.left}^{self.right}"

    def evaluate(self, env=None):
        """Evaluate both operands in ``env`` and raise left to the power right."""
        # Normalise once so children always receive a mapping.
        env = {} if env is None else env
        return self.left.evaluate(env) ** self.right.evaluate(env)

    def derive(self, dimension=Variable('x')):
        """Differentiate ``x^n`` with respect to ``dimension`` using the power rule.

        Only a ``Variable`` base is supported. The exponent must expose a
        numeric ``.value`` (e.g. a ``Const``). ``dimension`` may be a
        ``Variable`` or a variable name.

        Returns:
            ``Times(n, Exp(x, Const(n - 1)))`` when the base is ``dimension``.
            ``Const(0)`` when the base is a different variable.

        Raises:
            AssertionError: if the base is not a ``Variable``.
        """
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        # The exception type is unchanged for existing callers.
        if not isinstance(self.left, Variable):
            raise AssertionError("Left part must be variable.")
        if self.left.name != getattr(dimension, "name", dimension):
            # The base does not depend on the differentiation variable.
            return Const(0)
        # Wrap the lowered exponent in Const so the derivative can itself be
        # evaluated (a raw int has no ``evaluate``).
        lowered = Const(self.right.value - 1)
        return Times(self.right, Exp(self.left, lowered))
class ExpressionTest(TestCase):
    """Unit tests for the expression classes.

    Uses ``TestCase`` assertion methods instead of bare ``assert``. Bare
    asserts are removed under ``python -O``, which would make every test pass
    vacuously, and they report no values on failure.
    """

    def test_constants(self):
        c1 = Const(3)
        c2 = Const(4)
        self.assertEqual(7, c1 + c2)
        self.assertEqual(12, c1 * c2)

    def test_expression(self):
        e1 = Times(Const(3), Plus(Variable('x'), Variable('y')))
        self.assertEqual('3 * (x + y)', str(e1))
        e2 = Plus(Times(Const(3), Variable('x')), Variable('y'))
        self.assertEqual('(3 * x + y)', str(e2))

    def test_evaluation(self):
        env = {'x': 3, 'y': 5}
        e1 = Times(Const(3), Plus(Variable('x'), Variable('y')))
        self.assertEqual(24, e1.evaluate(env))
        e2 = Exp(Const(2), Const(3))
        self.assertEqual(8, e2.evaluate(env))

    def test_polynom(self):
        e1 = Sum(
            Times(Const(3), Exp(Variable('x'), Const(2))),
            Times(Const(2), Variable('y')),
            Const(3),
        )
        self.assertEqual('3 * x^2 + 2 * y + 3', str(e1))

    def test_sum(self):
        s = Sum(2, 4, 6)
        self.assertEqual('2 + 4 + 6', str(s))
        self.assertEqual(12, s.evaluate())

    def test_derive(self):
        s = Exp(Variable('x'), Const(4))
        self.assertEqual('4 * x^3', str(s.derive()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment