Skip to content

Instantly share code, notes, and snippets.

@flomnes
Last active September 4, 2024 16:56
Show Gist options
  • Save flomnes/68dfc8dfd57cf38317f116dd222eca87 to your computer and use it in GitHub Desktop.
Save flomnes/68dfc8dfd57cf38317f116dd222eca87 to your computer and use it in GitHub Desktop.
AST Simplification
#include <memory>
#include <optional>
#include <string>
#include <typeinfo>
#include <utility>
// Polymorphic root of the expression-tree hierarchy.
// It carries no data; the virtual destructor makes the type polymorphic so
// that dynamic_cast / typeid work on Node pointers and so that deleting a
// derived node through a Node* is well defined.
struct Node
{
    virtual ~Node() = default;
};
// Leaf node holding a numeric constant.
struct Literal : Node
{
    // The constant carried by this leaf.
    double value;

    // Builds a literal leaf from the number `v`.
    Literal(double v) : value{v} {}
};
// Leaf node holding a string (the tests in this file use it as a parameter name).
struct Parameter : Node
{
    // The string carried by this leaf; a copy of the constructor argument.
    std::string value;

    // Builds a parameter leaf that stores a copy of `v`.
    Parameter(const std::string& v) : value(v) {}
};
// Binary sum node: `first` and `second` (inherited from std::pair) are the
// two operands. The pair constructors are inherited, so a node is built
// directly from its two child pointers. The children are raw pointers and
// this class declares no destructor of its own, so destroying the node does
// not delete them.
struct AddNode : std::pair<Node*, Node*>, Node
{
    using std::pair<Node*, Node*>::pair;
};
// Binary product node: `first` and `second` (inherited from std::pair) are
// the two factors. The pair constructors are inherited, so a node is built
// directly from its two child pointers. The children are raw pointers and
// this class declares no destructor of its own, so destroying the node does
// not delete them.
struct MultNode : std::pair<Node*, Node*>, Node
{
    using std::pair<Node*, Node*>::pair;
};
// Bit flags selecting which rewrite rules handleAddNode may apply.
// Flags are combined with bitwise OR.
namespace Simplifications
{
// literal + literal  ->  single literal
constexpr unsigned int addTwoLiterals = 0x1U;
// 0 + e  or  e + 0  ->  e
constexpr unsigned int addToZero = 0x2U;
// c1 * e + c2 * e  ->  (c1 + c2) * e
constexpr unsigned int factorExpressions = 0x4U;
// Every rule enabled.
constexpr unsigned int all = factorExpressions | addToZero | addTwoLiterals;
} // namespace Simplifications
// Splits a product into (numeric coefficient, other operand).
//
// The first operand is tested first: when it is a Literal, its value is the
// coefficient and the second operand is returned alongside it. Otherwise the
// second operand is tested the same way. When neither operand is a Literal
// the result is an empty optional.
//
// `n` must not be null; it is dereferenced unconditionally.
std::optional<std::pair<double, Node*>> getLiteral(MultNode* n)
{
    if (const auto* lhs = dynamic_cast<Literal*>(n->first))
    {
        return std::make_pair(lhs->value, n->second);
    }
    if (const auto* rhs = dynamic_cast<Literal*>(n->second))
    {
        return std::make_pair(rhs->value, n->first);
    }
    return std::nullopt;
}
// Structural comparison of two expression trees.
//
// NOTE the inverted convention:
//   returns 'false' if the nodes compare equal,
//           'true'  if the nodes differ.
//
// Two nodes are equal when they have the same dynamic type and
//   - Literal / Parameter: their `value` members are equal;
//   - AddNode / MultNode : their first operands are equal and their second
//     operands are equal (operand order matters: a + b differs from b + a).
// Two null pointers compare equal; a null pointer differs from any non-null
// node. Nodes of any other type are always reported as different.
bool compare(Node* n1, Node* n2)
{
    // Null operands must be handled before typeid: applying typeid to a
    // dereferenced null pointer of polymorphic type throws std::bad_typeid.
    if (!n1 || !n2)
    {
        return n1 != n2;
    }

    // Different dynamic types can never be equal
    if (typeid(*n1) != typeid(*n2))
    {
        return true;
    }

    // From here on both nodes share the same dynamic type, so whenever the
    // cast of n1 succeeds the matching cast of n2 succeeds as well.
    if (const auto* lit1 = dynamic_cast<Literal*>(n1))
    {
        const auto* lit2 = dynamic_cast<Literal*>(n2);
        return lit1->value != lit2->value;
    }

    if (const auto* param1 = dynamic_cast<Parameter*>(n1))
    {
        const auto* param2 = dynamic_cast<Parameter*>(n2);
        return param1->value != param2->value;
    }

    if (const auto* add1 = dynamic_cast<AddNode*>(n1))
    {
        const auto* add2 = dynamic_cast<AddNode*>(n2);
        return compare(add1->first, add2->first) || compare(add1->second, add2->second);
    }

    if (const auto* mult1 = dynamic_cast<MultNode*>(n1))
    {
        const auto* mult2 = dynamic_cast<MultNode*>(n2);
        return compare(mult1->first, mult2->first) || compare(mult1->second, mult2->second);
    }

    // Unsupported node type (shouldn't happen with the current hierarchy)
    return true;
}
// Helper for handleAddNode: rewrites `c1 * e + c2 * e` as `(c1 + c2) * e`.
// Returns a newly allocated MultNode when both operands of `add` are products
// with a Literal coefficient and structurally equal remaining operands;
// returns nullptr when the rule does not apply.
static Node* factorCommonExpression(const AddNode& add)
{
    auto* lhs = dynamic_cast<MultNode*>(add.first);
    auto* rhs = dynamic_cast<MultNode*>(add.second);
    if (!lhs || !rhs)
    {
        return nullptr;
    }

    const auto lhsParts = getLiteral(lhs);
    const auto rhsParts = getLiteral(rhs);
    if (!lhsParts || !rhsParts)
    {
        return nullptr;
    }

    const auto& [coeff1, expr1] = *lhsParts;
    const auto& [coeff2, expr2] = *rhsParts;
    // compare() returns false when the two expressions are equal
    if (compare(expr1, expr2))
    {
        return nullptr;
    }

    // Q : shall we force simplification here ?
    return new MultNode(new Literal(coeff1 + coeff2), expr1);
}

// Applies the rewrite rules selected by the bit mask `s` (see namespace
// Simplifications) to the sum `node`. Rules are tried in this order and the
// first one that applies determines the result:
//   1. addTwoLiterals    : literal + literal -> new Literal holding the sum
//   2. addToZero         : 0 + e or e + 0    -> the operand e itself
//   3. factorExpressions : c1*e + c2*e       -> new MultNode (c1 + c2) * e
// When no rule applies, `node` itself is returned; a null `node` yields null.
//
// Only the top-level sum is inspected: operands are not simplified
// recursively. Nodes created here are allocated with `new`; this function
// never deletes `node` or any of its operands.
Node* handleAddNode(AddNode* node, unsigned int s)
{
    if (!node)
    {
        return node;
    }

    if (s & Simplifications::addTwoLiterals)
    {
        const auto* x = dynamic_cast<Literal*>(node->first);
        const auto* y = dynamic_cast<Literal*>(node->second);
        if (x && y)
        {
            return new Literal(x->value + y->value);
        }
    }

    if (s & Simplifications::addToZero)
    {
        const auto* x = dynamic_cast<Literal*>(node->first);
        if (x && x->value == 0)
        {
            return node->second;
        }
        const auto* y = dynamic_cast<Literal*>(node->second);
        if (y && y->value == 0)
        {
            return node->first;
        }
    }

    if (s & Simplifications::factorExpressions)
    {
        // A non-applicable rule falls through instead of returning, so any
        // simplification added after this one still gets its chance.
        if (Node* factored = factorCommonExpression(*node))
        {
            return factored;
        }
    }

    // by default, return the original AddNode
    return node;
}
// Use case 1 : visit an existing tree
//
// Each visit() overload returns the simplified replacement for the visited
// node, which may be the node itself. The overloads are public so that the
// visitor can actually be invoked from outside the class.
class NodeVisitor
{
public:
    // Simplifies a sum with every rule enabled (Simplifications::all).
    Node* visit(AddNode* add)
    {
        return handleAddNode(add, Simplifications::all);
    }

    // TODO: no simplification is implemented for products yet; the node is
    // returned unchanged.
    Node* visit(MultNode* node)
    {
        return node;
    }
};
// Use case 2 : simplification at creation
//
// Builds the sum `n1 + n2` and immediately simplifies it with every rule
// enabled. The result is either the freshly created AddNode (no rule
// applied) or whatever handleAddNode returned in its place.
//
// The temporary AddNode is owned by a unique_ptr, so it is destroyed instead
// of leaked when a rule replaces it. Destroying it is safe: an AddNode does
// not delete its operands, and none of the values handleAddNode returns
// refers to the AddNode itself unless it *is* the AddNode. Operands that the
// simplified result no longer references (e.g. two folded literals) are not
// freed here.
Node* addSimplify(Node* n1, Node* n2)
{
    auto add = std::make_unique<AddNode>(n1, n2);
    Node* const simplified = handleAddNode(add.get(), Simplifications::all);
    if (simplified == add.get())
    {
        // No rule applied: hand the AddNode over to the caller
        return add.release();
    }
    return simplified;
}
#include <cassert>
// Builds and simplifies (a + b) + c: the first two literals are summed
// through addSimplify, then the third literal is added to that result the
// same way.
Node* add3(double a, double b, double c)
{
    Node* const firstTwo = addSimplify(new Literal(a), new Literal(b));
    Node* const total = addSimplify(firstTwo, new Literal(c));
    return total;
}
// Smoke tests for the simplifications. Each scope builds a small expression
// through addSimplify / add3 and checks the shape of the result with assert.
// Nodes are allocated with `new` and never freed in this function.
int main()
{
// Rule addTwoLiterals: 3 + 4 is folded into the single literal 7
{
Node* add = addSimplify(new Literal(3), new Literal(4));
auto* lit = dynamic_cast<Literal*>(add);
assert(lit);
assert(lit->value == 7);
}
// (3 + 3) + 0 collapses to the literal 6
{
Node* add = add3(3, 3, 0);
auto* r = dynamic_cast<Literal*>(add);
assert(r);
assert(r->value == 6);
}
// (3 + 3) + 1 collapses to the literal 7
{
Node* add = add3(3, 3, 1);
auto* r = dynamic_cast<Literal*>(add);
assert(r);
assert(r->value == 7);
}
// Rule factorExpressions: 3 * ha + 4 * ha becomes 7 * ha
{
Node* root = addSimplify(new MultNode(new Literal(3), new Parameter("ha")),
new MultNode(new Literal(4), new Parameter("ha")));
auto* mult = dynamic_cast<MultNode*>(root);
assert(mult);
// The factored product is built as (coefficient, expression)
auto* lit = dynamic_cast<Literal*>(mult->first);
assert(lit && lit->value == 7);
auto* param = dynamic_cast<Parameter*>(mult->second);
assert(param && param->value == "ha");
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment