Last active
September 4, 2024 16:56
-
-
Save flomnes/68dfc8dfd57cf38317f116dd222eca87 to your computer and use it in GitHub Desktop.
AST Simplification
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <optional>
#include <string>
#include <typeinfo>
#include <utility>
// Polymorphic base of every expression-tree node. It carries no data of its
// own; the virtual destructor makes Node polymorphic (so dynamic_cast and
// typeid work on Node*) and lets a derived node be deleted through a Node*.
class Node
{
public:
    virtual ~Node() noexcept = default;
};
class Literal: public Node | |
{ | |
public: | |
Literal(double v): | |
value(v) | |
{ | |
} | |
double value; | |
}; | |
class Parameter: public Node | |
{ | |
public: | |
Parameter(const std::string& v): | |
value(v) | |
{ | |
} | |
std::string value; | |
}; | |
class AddNode: public std::pair<Node*, Node*>, public Node | |
{ | |
public: | |
using std::pair<Node*, Node*>::pair; | |
}; | |
class MultNode: public std::pair<Node*, Node*>, public Node | |
{ | |
public: | |
using std::pair<Node*, Node*>::pair; | |
}; | |
// Bit flags selecting which rules handleAddNode applies; combine with '|'.
namespace Simplifications
{
// Literal(a) + Literal(b) -> Literal(a + b)
constexpr unsigned int addTwoLiterals = 1U << 0;
// 0 + e -> e  and  e + 0 -> e
constexpr unsigned int addToZero = 1U << 1;
// (c1 * e) + (c2 * e) -> (c1 + c2) * e
constexpr unsigned int factorExpressions = 1U << 2;
// Every rule above.
constexpr unsigned int all = addTwoLiterals | addToZero | factorExpressions;
} // namespace Simplifications
// Splits a product into (literal coefficient, remaining factor).
// If the left operand is a Literal, returns its value paired with the right
// operand; otherwise, if the right operand is a Literal, returns its value
// paired with the left operand. When both operands are Literals the left one
// is used as the coefficient; when neither is, returns std::nullopt.
// `n` must not be null (it is dereferenced unconditionally).
std::optional<std::pair<double, Node*>> getLiteral(MultNode* n)
{
    if (auto* lhs = dynamic_cast<Literal*>(n->first))
    {
        return std::make_pair(lhs->value, n->second);
    }
    if (auto* rhs = dynamic_cast<Literal*>(n->second))
    {
        return std::make_pair(rhs->value, n->first);
    }
    return std::nullopt;
}
// Structural *inequality* test between two expression trees.
//
// Returns 'false' if the nodes compare equal and 'true' if they differ
// (note the inverted sense compared with operator==).
// Two nodes are equal when they have the same dynamic type and:
//   - Literal:   the stored doubles compare equal with '==' (so -0.0 equals
//                0.0 and a NaN literal never equals another literal);
//   - Parameter: the stored names are identical;
//   - AddNode / MultNode: both operands are equal position by position
//                (a + b and b + a are reported as different).
// Any other node type is reported as different.
// Null pointers are accepted: two nulls are equal, null vs non-null differ.
bool compare(Node* n1, Node* n2)
{
    // Without this guard a null operand (e.g. a child of a default-constructed
    // AddNode/MultNode) would reach typeid(*nullptr), which throws
    // std::bad_typeid.
    if (!n1 || !n2)
    {
        return n1 != n2;
    }
    // Nodes of different dynamic types are never equal.
    if (typeid(*n1) != typeid(*n2))
    {
        return true;
    }
    // From here on both nodes share the same dynamic type, so each cast of n2
    // below succeeds whenever the matching cast of n1 does.
    if (auto* lit1 = dynamic_cast<Literal*>(n1))
    {
        auto* lit2 = dynamic_cast<Literal*>(n2);
        return lit1->value != lit2->value;
    }
    if (auto* par1 = dynamic_cast<Parameter*>(n1))
    {
        auto* par2 = dynamic_cast<Parameter*>(n2);
        return par1->value != par2->value;
    }
    if (auto* add1 = dynamic_cast<AddNode*>(n1))
    {
        auto* add2 = dynamic_cast<AddNode*>(n2);
        return compare(add1->first, add2->first) || compare(add1->second, add2->second);
    }
    if (auto* mult1 = dynamic_cast<MultNode*>(n1))
    {
        auto* mult2 = dynamic_cast<MultNode*>(n2);
        return compare(mult1->first, mult2->first) || compare(mult1->second, mult2->second);
    }
    // Unsupported node types (not expected with the current hierarchy).
    return true;
}
Node* handleAddNode(AddNode* node, unsigned int s) | |
{ | |
if (s & Simplifications::addTwoLiterals) | |
{ | |
auto* x = dynamic_cast<Literal*>(node->first); | |
auto* y = dynamic_cast<Literal*>(node->second); | |
if (x && y) | |
{ | |
return new Literal(x->value + y->value); | |
} | |
} | |
if (s & Simplifications::addToZero) | |
{ | |
auto* x = dynamic_cast<Literal*>(node->first); | |
if (x && x->value == 0) | |
{ | |
return node->second; | |
} | |
auto* y = dynamic_cast<Literal*>(node->second); | |
if (y && y->value == 0) | |
{ | |
return node->first; | |
} | |
} | |
if (s & Simplifications::factorExpressions) | |
{ | |
auto* x = dynamic_cast<MultNode*>(node->first); | |
auto* y = dynamic_cast<MultNode*>(node->second); | |
if (!x || !y) | |
{ | |
// TODO passer à la prochaine simplification | |
return node; | |
} | |
auto s1 = getLiteral(x); | |
auto s2 = getLiteral(y); | |
if (!s1 || !s2) | |
{ | |
// TODO passer à la prochaine simplification | |
return node; | |
} | |
const auto& [coeff1, expr1] = *s1; | |
const auto& [coeff2, expr2] = *s2; | |
if (!compare(expr1, expr2)) | |
{ | |
const double coeff = coeff1 + coeff2; | |
// Q : shall we force simplification here ? | |
return new MultNode(new Literal(coeff), expr1); | |
} | |
} | |
// by default, return the original AddNode | |
return node; | |
} | |
// Use case 1 : visit an existing tree | |
class NodeVisitor | |
{ | |
Node* visit(AddNode* add) | |
{ | |
return handleAddNode(add, Simplifications::all); | |
} | |
Node* visit(MultNode* node) | |
{ | |
// TODO | |
return node; | |
} | |
}; | |
// Use case 2: simplification at creation.
//
// Builds `n1 + n2` and immediately simplifies it with every rule enabled.
// Returns either the freshly built AddNode (when no rule applied) or its
// simplified replacement. In the latter case the temporary AddNode is freed
// here; the operands n1 and n2 are never freed, because they belong to the
// caller and may be reused by the result.
Node* addSimplify(Node* n1, Node* n2)
{
    auto* add = new AddNode(n1, n2);
    Node* result = handleAddNode(add, Simplifications::all);
    if (result != add)
    {
        // handleAddNode returns either `add`, one of its operands, or a brand
        // new node that does not contain `add`, so nothing refers to it any
        // more. Deleting it previously leaked on every successful
        // simplification. AddNode's destructor does not delete its operands.
        delete add;
    }
    return result;
}
#include <cassert> | |
// Builds a + b + c as (a + b) + c, simplifying after each addition via
// addSimplify. Every operand is wrapped in a freshly allocated Literal.
Node* add3(double a, double b, double c)
{
    Node* leftSum = addSimplify(new Literal(a), new Literal(b));
    return addSimplify(leftSum, new Literal(c));
}
int main() | |
{ | |
{ | |
Node* add = addSimplify(new Literal(3), new Literal(4)); | |
auto* lit = dynamic_cast<Literal*>(add); | |
assert(lit); | |
assert(lit->value == 7); | |
} | |
{ | |
Node* add = add3(3, 3, 0); | |
auto* r = dynamic_cast<Literal*>(add); | |
assert(r); | |
assert(r->value == 6); | |
} | |
{ | |
Node* add = add3(3, 3, 1); | |
auto* r = dynamic_cast<Literal*>(add); | |
assert(r); | |
assert(r->value == 7); | |
} | |
{ | |
Node* root = addSimplify(new MultNode(new Literal(3), new Parameter("ha")), | |
new MultNode(new Literal(4), new Parameter("ha"))); | |
auto* mult = dynamic_cast<MultNode*>(root); | |
assert(mult); | |
auto* lit = dynamic_cast<Literal*>(mult->first); | |
assert(lit && lit->value == 7); | |
auto* param = dynamic_cast<Parameter*>(mult->second); | |
assert(param && param->value == "ha"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment