#include <iostream>
#include <cmath>
#include <vector>
#include <cassert>
#include <memory>
enum NodeOp {
    OP_CONST,
    OP_ADD,
    OP_MUL,
    OP_EXP,
    OP_DIV
};
class NodeData;

class Node {
public:
    // A shared pointer does reference-counting,
    // so I don't have to worry about memory leaks.
    std::shared_ptr<NodeData> const ptr;

    double grad() const;
    double value() const;
    void set_value(double x);

    // Sets all gradient variables to zero in the entire
    // expression graph.
    void reset_grad();

    // Does backpropagation of the gradients.
    void backward(double g = 1.0);

    // Forward propagation of the values;
    // can be used on precomputed expression graphs.
    double forward();

    // A constructor for constant values and variables.
    Node(double c);

    Node operator*(const Node& that) const;
    Node operator+(const Node& that) const;
    Node operator/(const Node& that) const;
    Node exp() const;

    Node operator-() const {
        return *this * Node(-1.0);
    }

private:
    Node(std::shared_ptr<NodeData> const ptr);
};
class NodeData {
public:
    NodeOp op;
    double value;
    // This variable will be mutated during the backpropagation pass.
    double grad;
    std::vector<Node> children;
};
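
// For example, Node(2.0) * Node(3.0) heap-allocates an OP_MUL NodeData with
// value 6.0, grad 0.0, and the two constant operand nodes as children; the
// children's shared_ptrs keep the whole subgraph alive.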
Node::Node(double c)
    : ptr(std::shared_ptr<NodeData>(
          new NodeData { OP_CONST, c, 0.0, {} }
      )) { }

Node::Node(std::shared_ptr<NodeData> const ptr) : ptr(ptr) { }
double Node::grad() const {
    return ptr->grad;
}

double Node::value() const {
    return ptr->value;
}

void Node::set_value(double x) {
    ptr->value = x;
}
////////////////////////////////////////////////////////////////
// Backpropagation
////////////////////////////////////////////////////////////////

void Node::reset_grad() {
    ptr->grad = 0.0;
    for (auto c : ptr->children) {
        c.reset_grad();
    }
}
void Node::backward(double g) {
    // Gradients accumulate, so a node reached along several
    // paths sums the contributions from each path.
    ptr->grad += g;
    if (ptr->op == OP_CONST) {
        // A leaf: nothing to propagate.
    }
    else if (ptr->op == OP_ADD) {
        for (auto c : ptr->children) {
            c.backward(g);
        }
    }
    else if (ptr->op == OP_MUL) {
        // Easier to assume it's always 2 than deal with the other cases.
        // If it were not 2, I would have to multiply all args but one, which is messy.
        assert(ptr->children.size() == 2);
        Node x = ptr->children[0];
        Node y = ptr->children[1];
        // Let f(x, y) = x * y
        // and z = f(x, y).
        // We know dL/dz = g,
        // so dL/dx = dL/dz * df/dx = g * y
        // and dL/dy = dL/dz * df/dy = g * x.
        x.backward(g * y.value());
        y.backward(g * x.value());
    }
    else if (ptr->op == OP_DIV) {
        assert(ptr->children.size() == 2);
        Node x = ptr->children[0];
        Node y = ptr->children[1];
        // Let f(x, y) = x / y
        // and z = f(x, y).
        // We know dL/dz = g,
        // so dL/dx = dL/dz * df/dx = g / y
        // and dL/dy = dL/dz * df/dy = -g * x / y^2.
        x.backward(g / y.value());
        y.backward(-g * x.value() / (y.value() * y.value()));
    }
    else if (ptr->op == OP_EXP) {
        assert(ptr->children.size() == 1);
        Node x = ptr->children[0];
        // Let f(x) = exp(x)
        // and z = f(x).
        // We know dL/dz = g,
        // so dL/dx = dL/dz * df/dx = g * exp(x).
        x.backward(g * std::exp(x.value()));
    }
    else assert(false);
}
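
// Worked example (values assumed for illustration): for z = x * y with
// x = 2 and y = 3, z.backward(1.0) sets z.grad = 1, then propagates
// x.grad += 1 * y = 3 and y.grad += 1 * x = 2, matching dz/dx = y and dz/dy = x.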
// Forward pass.
double Node::forward() {
    if (ptr->op == OP_CONST) {
        // Do nothing: leaves keep whatever value was set on them.
    }
    else if (ptr->op == OP_ADD) {
        double sum = 0.0;
        for (auto c : ptr->children) {
            sum += c.forward();
        }
        set_value(sum);
    }
    else if (ptr->op == OP_MUL) {
        double product = 1.0;
        for (auto c : ptr->children) {
            product *= c.forward();
        }
        set_value(product);
    }
    else if (ptr->op == OP_DIV) {
        assert(ptr->children.size() == 2);
        double x = ptr->children[0].forward();
        double y = ptr->children[1].forward();
        set_value(x / y);
    }
    else if (ptr->op == OP_EXP) {
        assert(ptr->children.size() == 1);
        // Recurse with forward() (not value()) so the child subtree
        // is recomputed before taking the exponential.
        set_value(std::exp(ptr->children[0].forward()));
    }
    else assert(false);
    return value();
}
////////////////////////////////////////////////////////////////
// Operators
////////////////////////////////////////////////////////////////

Node Node::operator*(const Node& that) const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_MUL, ptr->value * that.ptr->value, 0.0, { *this, that } }
        )
    };
}

Node Node::operator+(const Node& that) const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_ADD, ptr->value + that.ptr->value, 0.0, { *this, that } }
        )
    };
}

Node Node::operator/(const Node& that) const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_DIV, ptr->value / that.ptr->value, 0.0, { *this, that } }
        )
    };
}

Node Node::exp() const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_EXP, std::exp(ptr->value), 0.0, { *this } }
        )
    };
}

// Mixed double/Node overloads; + and * are commutative, so both
// orderings reuse the same Node-Node operator.
Node operator+(double x, const Node& n) {
    return Node(x) + n;
}
Node operator+(const Node& n, double x) {
    return Node(x) + n;
}
Node operator*(double x, const Node& n) {
    return Node(x) * n;
}
Node operator*(const Node& n, double x) {
    return Node(x) * n;
}
Node operator/(double x, const Node& n) {
    return Node(x) / n;
}
////////////////////////////////////////////////////////////////

Node sigmoid(Node x) {
    return 1.0 / (1.0 + (-x).exp());
}
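
// A minimal sketch of a finite-difference check for backward(). The helper
// name check_grad and the step size h are assumptions added here, not part
// of the original gist. Call it after output.backward(1.0): it nudges one
// leaf variable, recomputes the graph with forward(), and compares the
// numerical slope against the gradient stored by backpropagation.
bool check_grad(Node output, Node input, double h = 1e-6) {
    double x0 = input.value();
    input.set_value(x0 + h);
    double f_plus = output.forward();
    input.set_value(x0 - h);
    double f_minus = output.forward();
    // Restore the original value and recompute the graph.
    input.set_value(x0);
    output.forward();
    double numeric = (f_plus - f_minus) / (2.0 * h);
    return std::abs(numeric - input.grad()) < 1e-4;
}
// Example use in main (assumed): assert(check_grad(o, w13));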
int main() {
    Node w13(0.1);
    Node w14(0.4);
    Node w24(0.6);
    Node w23(0.8);
    Node w35(0.3);
    Node w45(0.9);

    double x1 = 0.35;
    double x2 = 0.9;

    Node i3 = x1 * w13 + x2 * w23;
    Node i4 = x1 * w14 + x2 * w24;
    Node h3 = sigmoid(i3);
    Node h4 = sigmoid(i4);
    Node i5 = h3 * w35 + h4 * w45;
    Node o = sigmoid(i5);

    o.backward(1.0);

    std::cout << "w13 = " << w13.value() << " dL/dw13 = " << w13.grad() << std::endl;
    std::cout << "w14 = " << w14.value() << " dL/dw14 = " << w14.grad() << std::endl;
    std::cout << "w23 = " << w23.value() << " dL/dw23 = " << w23.grad() << std::endl;
    std::cout << "w24 = " << w24.value() << " dL/dw24 = " << w24.grad() << std::endl;
    std::cout << "w35 = " << w35.value() << " dL/dw35 = " << w35.grad() << std::endl;
    std::cout << "w45 = " << w45.value() << " dL/dw45 = " << w45.grad() << std::endl;
    std::cout << "i3 = " << i3.value() << " dL/di3 = " << i3.grad() << std::endl;
    std::cout << "i4 = " << i4.value() << " dL/di4 = " << i4.grad() << std::endl;
    std::cout << "h3 = " << h3.value() << " dL/dh3 = " << h3.grad() << std::endl;
    std::cout << "h4 = " << h4.value() << " dL/dh4 = " << h4.grad() << std::endl;
    std::cout << "i5 = " << i5.value() << " dL/di5 = " << i5.grad() << std::endl;
    std::cout << "o = " << o.value() << " dL/do = " << o.grad() << std::endl;
    // A gradient-descent sketch (commented out). The variables x, y and the
    // concrete loss below are placeholders assumed for illustration; the
    // original left them unspecified.
    // Node x(2.0), y(-1.5);
    // Node l0 = x * x + y * y;  // precomputed expression graph, reused every step
    // for (int i = 0; i < 30; i++) {
    //     // Some loss function with a global minimum.
    //     Node loss = x * x + y * y;
    //     // Check that the precomputed expression graph produces the same result.
    //     l0.forward();
    //     assert(std::abs(l0.value() - loss.value()) < 0.00001);
    //     loss.reset_grad();
    //     // Backpropagation step.
    //     loss.backward(1.0);
    //     // One SGD step.
    //     x.set_value(x.value() - x.grad() * 0.01);
    //     y.set_value(y.value() - y.grad() * 0.01);
    //     std::cout << i << " " << loss.value() << " " << x.grad() << " " << y.grad() << std::endl;
    // }
}
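
// Build and run (filename assumed): g++ -std=c++11 autodiff.cpp -o autodiff && ./autodiff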