#include <iostream>
#include <cmath>
#include <vector>
#include <cassert>
#include <memory>

enum NodeOp {
  OP_CONST,  // a leaf: a constant or a variable
  OP_ADD,    // sum of the children
  OP_MUL,    // product of the children
  OP_EXP     // exp of a single child
};
class NodeData;

class Node {
public:
  // A shared pointer does reference-counting,
  // so I don't have to worry about memory leaks.
  std::shared_ptr<NodeData> const ptr;

  double grad() const;
  double value() const;
  void set_value(double x);

  // Sets all gradient variables to zero in the entire
  // expression graph.
  void reset_grad();

  // Does backpropagation of the gradients.
  void backward(double g = 1.0);

  // Forward propagation of the values;
  // can be used on precomputed expression graphs.
  double forward();

  // A constructor for constant values and variables.
  Node(double c);

  Node operator*(const Node& that) const;
  Node operator+(const Node& that) const;
  Node exp() const;

private:
  Node(std::shared_ptr<NodeData> const ptr);
};
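
// Illustrative usage of the interface above:
//
//   Node a(2.0), b(3.0);
//   Node f = (a * b).exp();   // builds the graph; f.value() == exp(6)
//   f.reset_grad();
//   f.backward();             // a.grad() == 3 * exp(6), b.grad() == 2 * exp(6)
//   a.set_value(1.0);
//   f.forward();              // re-evaluates the graph; f.value() == exp(3)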

class NodeData {
public:
  NodeOp op;
  double value;
  // This variable will be mutated during the backpropagation pass.
  double grad;
  std::vector<Node> children;
};
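
// Because children hold shared_ptrs, the same NodeData can appear in several
// places: in "x * x + x" the NodeData behind x is referenced three times.
// This sharing is why backward() accumulates into grad with += rather than
// overwriting it.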

Node::Node(double c)
  : ptr(std::shared_ptr<NodeData>(
      new NodeData { OP_CONST, c, 0.0, {} }
    )) { }

Node::Node(std::shared_ptr<NodeData> const ptr) : ptr(ptr) { }

double Node::grad() const {
  return ptr->grad;
}

double Node::value() const {
  return ptr->value;
}

void Node::set_value(double x) {
  ptr->value = x;
}

////////////////////////////////////////////////////////////////
// Backpropagation
////////////////////////////////////////////////////////////////

void Node::reset_grad() {
  ptr->grad = 0.0;
  for (auto& c : ptr->children) {
    c.reset_grad();
  }
}

void Node::backward(double g) {
  ptr->grad += g;
  if (ptr->op == OP_CONST) {}
  else if (ptr->op == OP_ADD) {
    for (auto& c : ptr->children) {
      c.backward(g);
    }
  }
  else if (ptr->op == OP_MUL) {
    // Easier to assume it's always 2 than deal with the other cases.
    // If it is not 2 I would have to multiply all args but one, which is messy.
    assert(ptr->children.size() == 2);
    Node x = ptr->children[0];
    Node y = ptr->children[1];
    x.backward(g * y.value());
    y.backward(g * x.value());
    // Let f(x, y) = x * y
    // and z = f(x, y)
    // We know dL / dz = g
    // then dL / dx = dL / dz * df / dx = g * y
  }
  else if (ptr->op == OP_EXP) {
    assert(ptr->children.size() == 1);
    Node x = ptr->children[0];
    x.backward(g * std::exp(x.value()));
    // Let f(x) = exp(x)
    // and z = f(x)
    // We know dL / dz = g
    // then dL / dx = dL / dz * df / dx = g * exp(x)
  }
  else assert(false);
}
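
// Worked example: for L = x * y + x with x = 2 and y = 3, backward(1.0)
// on the OP_ADD root passes g = 1 to both children. The OP_MUL node then
// propagates 1 * y = 3 into x and 1 * x = 2 into y, and the root's second
// child (x itself) adds another 1, so x.grad() == 4 and y.grad() == 2,
// matching dL/dx = y + 1 and dL/dy = x. Since grad accumulates with +=,
// reset_grad() must be called before reusing a graph for another pass.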

// Forward pass.
double Node::forward() {
  if (ptr->op == OP_CONST) {
    // do nothing.
  }
  else if (ptr->op == OP_ADD) {
    double sum = 0.0;
    for (auto& c : ptr->children) {
      sum += c.forward();
    }
    set_value(sum);
  }
  else if (ptr->op == OP_MUL) {
    double product = 1.0;
    for (auto& c : ptr->children) {
      product *= c.forward();
    }
    set_value(product);
  }
  else if (ptr->op == OP_EXP) {
    assert(ptr->children.size() == 1);
    // Recurse into the child first so its value is up to date;
    // reading the stored value here would use stale data after set_value.
    set_value(std::exp(ptr->children[0].forward()));
  }
  else assert(false);
  return value();
}
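
// A graph built once can be re-evaluated after its leaves change, e.g.
// x.set_value(5.0); l0.forward(); returns the new value of the expression
// without rebuilding any nodes; this is how main() below reuses l0.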

////////////////////////////////////////////////////////////////
// Operators
////////////////////////////////////////////////////////////////

Node Node::operator*(const Node& that) const {
  return Node {
    std::shared_ptr<NodeData>(
      new NodeData { OP_MUL, ptr->value * that.ptr->value, 0.0, { *this, that } }
    )
  };
}

Node Node::operator+(const Node& that) const {
  return Node {
    std::shared_ptr<NodeData>(
      new NodeData { OP_ADD, ptr->value + that.ptr->value, 0.0, { *this, that } }
    )
  };
}

Node Node::exp() const {
  return Node {
    std::shared_ptr<NodeData>(
      new NodeData { OP_EXP, std::exp(ptr->value), 0.0, { *this } }
    )
  };
}
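
// Note that each operator computes its value eagerly at construction time,
// so a freshly built expression already holds the correct result; forward()
// is only needed after a leaf has been changed with set_value().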

// Mixed double/Node overloads. The Node parameter is taken by const
// reference so that temporaries (e.g. the result of x * y) bind as well.
Node operator+(double x, const Node& n) {
  return Node(x) + n;
}
Node operator+(const Node& n, double x) {
  return n + Node(x);
}
Node operator*(double x, const Node& n) {
  return Node(x) * n;
}
Node operator*(const Node& n, double x) {
  return n * Node(x);
}
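
// Overload resolution for the expressions used below: "3 * x" picks
// operator*(double, const Node&), an expression like "... + 10" picks
// operator+(const Node&, double), and Node-with-Node arithmetic such as
// "x * y" goes through the member operators.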

////////////////////////////////////////////////////////////////

int main() {
  Node x(2);
  Node y(1);

  // Precomputed expression graph. It is not *necessary* to precompute
  // graphs; the loss below is rebuilt from scratch on every iteration.
  Node l0 = 3 * x * x + y * y + x * y + 10;

  for (int i = 0; i < 30; i++) {
    // Some loss function with a global minimum.
    Node loss = 3 * x * x + y * y + x * y + 10;

    // Check that the precomputed expression graph produces the same result.
    l0.forward();
    assert(std::abs(l0.value() - loss.value()) < 0.00001);

    loss.reset_grad();
    // Backpropagation step.
    loss.backward(1.0);

    // One SGD step.
    x.set_value(x.value() - x.grad() * 0.01);
    y.set_value(y.value() - y.grad() * 0.01);

    std::cout << i << " " << loss.value() << " " << x.grad() << " " << y.grad() << std::endl;
  }
}
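
// Build and run (assuming the file is saved as autodiff.cpp):
//   g++ -std=c++11 -o autodiff autodiff.cpp && ./autodiff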