Last active
January 12, 2022 08:13
-
-
Save kkroesch/6a5c068ee734a695d4528d24574e7c12 to your computer and use it in GitHub Desktop.
Model algebraic expressions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from numbers import Number | |
from unittest import TestCase | |
class Expression:
    """Common base for every node in an algebraic expression tree.

    Subclasses such as ``Const`` and ``Variable`` define ``__str__`` and
    ``evaluate``; this base only provides a fallback ``derive``.
    """

    def derive(self, variable):
        """Differentiate this expression with respect to *variable*.

        This fallback ignores *variable* and hands back the receiver itself,
        unchanged; node types that know how to differentiate override it.
        """
        unchanged = self
        return unchanged
@dataclass
class BinaryExpression(Expression):
    """An operator node with exactly two operands.

    ``@dataclass`` generates ``__init__(left, right)``, ``__repr__`` and
    ``__eq__``; operator subclasses inherit these and add their own
    ``__str__`` and ``evaluate``.
    """

    left: Expression  # first (left-hand) operand
    right: Expression  # second (right-hand) operand
@dataclass
class Const(Expression):
    """A numeric constant leaf node.

    Arithmetic on constants yields plain numbers, not new ``Const`` nodes:
    ``Const(3) + Const(4) == 7`` and ``Const(3) * Const(4) == 12``.
    Plain numbers are accepted on either side (``Const(3) + 4``,
    ``2 * Const(3)``), which also lets ``sum()`` work over constants.
    """

    value: Number

    def __str__(self) -> str:
        return str(self.value)

    def evaluate(self, env=None):
        """Return the constant's value.

        :param env: accepted for a uniform ``evaluate`` signature; ignored.
            Defaults to ``None`` rather than ``{}`` to avoid a shared
            mutable default argument.
        """
        return self.value

    @staticmethod
    def _operand_value(other):
        """Return the numeric value of *other*, or ``NotImplemented``.

        Plain numbers are used as-is; any other operand must expose a
        ``.value`` attribute (e.g. another ``Const``).
        """
        if isinstance(other, Number):
            return other
        return getattr(other, "value", NotImplemented)

    def __add__(self, other):
        other_value = self._operand_value(other)
        if other_value is NotImplemented:
            # Follow the binary-operator protocol: let Python try the other
            # operand's __radd__ and raise TypeError, instead of crashing
            # with an AttributeError here.
            return NotImplemented
        return self.value + other_value

    def __radd__(self, other):
        # Supports ``2 + Const(3)`` and ``sum()`` (which starts from 0).
        other_value = self._operand_value(other)
        if other_value is NotImplemented:
            return NotImplemented
        return other_value + self.value

    def __mul__(self, other):
        other_value = self._operand_value(other)
        if other_value is NotImplemented:
            return NotImplemented
        return self.value * other_value

    def __rmul__(self, other):
        # Supports ``2 * Const(3)``.
        other_value = self._operand_value(other)
        if other_value is NotImplemented:
            return NotImplemented
        return other_value * self.value
@dataclass
class Variable(Expression):
    """A named variable leaf node; its value is looked up when evaluated."""

    name: str

    def __str__(self) -> str:
        return str(self.name)

    def evaluate(self, env=None):
        """Return the value bound to this variable's name in *env*.

        :param env: mapping from variable name to value; ``None`` (the
            default, replacing a shared mutable ``{}``) means no bindings.
        :raises KeyError: if ``self.name`` is not bound in *env*.
        """
        bindings = {} if env is None else env
        return bindings[self.name]
class Times(BinaryExpression):
    """Product node ``left * right``.

    Rendered without surrounding parentheses, e.g. ``3 * x``.
    """

    def __str__(self) -> str:
        return f"{self.left} * {self.right}"

    def evaluate(self, env=None):
        """Multiply the evaluated operands.

        :param env: mapping from variable name to value; ``None`` means no
            bindings.
        """
        # Normalise to a dict so operands never receive ``None`` and no
        # mutable default is shared between calls.
        env = {} if env is None else env
        return self.left.evaluate(env) * self.right.evaluate(env)
class Plus(BinaryExpression):
    """Binary addition node ``left + right``.

    Rendered wrapped in parentheses, e.g. ``(x + y)``, so it reads correctly
    as an operand of ``Times``.
    """

    def __str__(self) -> str:
        return f"({self.left} + {self.right})"

    def evaluate(self, env=None):
        """Add the evaluated operands.

        :param env: mapping from variable name to value; ``None`` means no
            bindings.
        """
        # Normalise to a dict so operands never receive ``None`` and no
        # mutable default is shared between calls.
        env = {} if env is None else env
        return self.left.evaluate(env) + self.right.evaluate(env)
@dataclass
class Sum(Expression):
    """An n-ary sum, rendered as ``a + b + c`` without enclosing parentheses.

    Summands may be expression nodes or plain numbers, e.g. ``Sum(2, 4, 6)``
    or ``Sum(Times(Const(2), Variable('y')), Const(3))``.
    """

    summands: tuple  # exactly the positional arguments passed to __init__

    def __init__(self, *args):
        # Hand-written so summands are passed positionally; @dataclass keeps
        # a user-defined __init__ and still generates __repr__/__eq__.
        self.summands = args

    def __str__(self) -> str:
        return ' + '.join(map(str, self.summands))

    def evaluate(self, env=None):
        """Return the sum of all summands evaluated under *env*.

        Expression summands are evaluated with *env* (previously they were
        handed to ``sum()`` raw, which raised ``TypeError``); any other
        summand, such as a plain number, is added as-is. ``Sum()`` is 0.

        :param env: mapping from variable name to value; ``None`` means no
            bindings.
        """
        env = {} if env is None else env
        return sum(
            term.evaluate(env) if isinstance(term, Expression) else term
            for term in self.summands
        )
class Exp(BinaryExpression):
    """Power node ``left ^ right``; evaluates as ``left ** right``."""

    def __str__(self) -> str:
        return f"{self.left}^{self.right}"

    def evaluate(self, env=None):
        """Raise the evaluated base to the evaluated exponent.

        :param env: mapping from variable name to value; ``None`` means no
            bindings.
        """
        env = {} if env is None else env
        return self.left.evaluate(env) ** self.right.evaluate(env)

    def derive(self, dimension=Variable('x')):
        """Differentiate ``v^c`` by the power rule, giving ``c * v^(c-1)``.

        :param dimension: NOTE(review): currently unused; the derivative is
            always taken with respect to ``self.left``. Kept for interface
            compatibility.
        :raises AssertionError: if ``self.left`` is not a ``Variable``.

        Assumes ``self.right`` is a ``Const`` (its ``.value`` is read); any
        other exponent raises ``AttributeError``.
        """
        # Explicit raise (same exception type and message as the former
        # ``assert``) so the check is not stripped under ``python -O``.
        if not isinstance(self.left, Variable):
            raise AssertionError("Left part must be variable.")
        # Wrap the reduced exponent in Const so the result is a real
        # expression tree that can itself be evaluated; str() output is
        # unchanged because str(Const(n)) == str(n).
        return Times(self.right, Exp(self.left, Const(self.right.value - 1)))
class ExpressionTest(TestCase):
    """Unit tests for building, printing, evaluating and deriving expressions.

    Uses ``self.assertEqual`` / ``self.assertRaises`` rather than bare
    ``assert`` so failures report both values and the checks still run
    under ``python -O``.
    """

    def test_constants(self):
        c1 = Const(3)
        c2 = Const(4)
        self.assertEqual(c1 + c2, 7)
        self.assertEqual(c1 * c2, 12)

    def test_expression(self):
        e1 = Times(Const(3), Plus(Variable('x'), Variable('y')))
        self.assertEqual(str(e1), '3 * (x + y)')
        # Plus parenthesises itself; Times does not.
        e2 = Plus(Times(Const(3), Variable('x')), Variable('y'))
        self.assertEqual(str(e2), '(3 * x + y)')

    def test_evaluation(self):
        env = {'x': 3, 'y': 5}
        e1 = Times(Const(3), Plus(Variable('x'), Variable('y')))
        self.assertEqual(e1.evaluate(env), 24)
        e2 = Exp(Const(2), Const(3))
        self.assertEqual(e2.evaluate(env), 8)

    def test_unbound_variable(self):
        with self.assertRaises(KeyError):
            Variable('z').evaluate({'x': 1})

    def test_polynom(self):
        e1 = Sum(
            Times(Const(3), Exp(Variable('x'), Const(2))),
            Times(Const(2), Variable('y')),
            Const(3),
        )
        self.assertEqual(str(e1), '3 * x^2 + 2 * y + 3')

    def test_sum(self):
        s = Sum(2, 4, 6)
        self.assertEqual(str(s), '2 + 4 + 6')
        self.assertEqual(s.evaluate(), 12)

    def test_derive(self):
        s = Exp(Variable('x'), Const(4))
        self.assertEqual(str(s.derive()), '4 * x^3')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment