#include <iostream>
#include <cmath>
#include <vector>
#include <cassert>
#include <memory>
enum NodeOp {
    OP_CONST,
    OP_ADD,
    OP_MUL,
    OP_EXP,
    OP_DIV
};
class NodeData;
class Node {
public:
    // A shared pointer does reference-counting,
    // so I don't have to worry about memory leaks.
    std::shared_ptr<NodeData> const ptr;

    double grad() const;
    double value() const;
    void set_value(double x);

    // Sets all gradient variables to zero in the entire
    // expression graph.
    void reset_grad();

    // Does backpropagation of the gradients.
    void backward(double g = 1.0);

    // Forward propagation of the values; can be used to
    // re-evaluate a prebuilt expression graph after set_value().
    double forward();

    // A constructor for constant values and variables.
    Node(double c);

    Node operator*(const Node& that) const;
    Node operator+(const Node& that) const;
    Node operator/(const Node& that) const;
    Node exp() const;

    Node operator-() const {
        return *this * Node(-1.0);
    }

private:
    Node(std::shared_ptr<NodeData> const ptr);
};
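// Typical usage (a minimal sketch; `backward` defaults to a seed
// gradient of 1.0, i.e. dz/dz = 1):
//
//   Node x(2.0), y(3.0);
//   Node z = x * y + x;   // values are computed eagerly: z.value() == 8
//   z.backward();         // dz/dx = y + 1 = 4, dz/dy = x = 2
//   std::cout << x.grad() << " " << y.grad() << std::endl;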
class NodeData {
public:
    NodeOp op;
    double value;
    // This variable will be mutated during the backpropagation pass.
    double grad;
    std::vector<Node> children;
};
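// Note on ownership: every Node holds a shared_ptr to its NodeData, and a
// NodeData keeps its operands alive through `children`. Since expressions
// only ever point from results back to operands, the graph is acyclic, so
// the reference counts cannot form a leak-inducing cycle.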
Node::Node(double c)
    : ptr(std::shared_ptr<NodeData>(
          new NodeData { OP_CONST, c, 0.0, {} }
      )) { }

Node::Node(std::shared_ptr<NodeData> const ptr) : ptr(ptr) { }

double Node::grad() const {
    return ptr->grad;
}

double Node::value() const {
    return ptr->value;
}

void Node::set_value(double x) {
    ptr->value = x;
}
////////////////////////////////////////////////////////////////
// Backpropagation
////////////////////////////////////////////////////////////////
void Node::reset_grad() {
    ptr->grad = 0.0;
    for (auto c : ptr->children) {
        c.reset_grad();
    }
}
void Node::backward(double g) {
    ptr->grad += g;
    if (ptr->op == OP_CONST) {
        // Constants and variables are leaves: nothing to propagate.
    }
    else if (ptr->op == OP_ADD) {
        for (auto c : ptr->children) {
            c.backward(g);
        }
    }
    else if (ptr->op == OP_MUL) {
        // Easier to assume there are always exactly two factors than to
        // handle the general case: with n factors, each dL/dx_i would be
        // g times the product of all the other factors, which is messy.
        assert(ptr->children.size() == 2);
        Node x = ptr->children[0];
        Node y = ptr->children[1];
        // Let f(x, y) = x * y
        // and z = f(x, y).
        // We know dL/dz = g,
        // so dL/dx = dL/dz * df/dx = g * y
        // and dL/dy = dL/dz * df/dy = g * x.
        x.backward(g * y.value());
        y.backward(g * x.value());
    }
    else if (ptr->op == OP_DIV) {
        assert(ptr->children.size() == 2);
        Node x = ptr->children[0];
        Node y = ptr->children[1];
        // Let f(x, y) = x / y
        // and z = f(x, y).
        // We know dL/dz = g,
        // so dL/dx = dL/dz * df/dx = g / y
        // and dL/dy = dL/dz * df/dy = -g * x / y^2.
        x.backward(g / y.value());
        y.backward(-g * x.value() / (y.value() * y.value()));
    }
    else if (ptr->op == OP_EXP) {
        assert(ptr->children.size() == 1);
        Node x = ptr->children[0];
        // Let f(x) = exp(x)
        // and z = f(x).
        // We know dL/dz = g,
        // so dL/dx = dL/dz * df/dx = g * exp(x).
        x.backward(g * std::exp(x.value()));
    }
    else assert(false);
}
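// Why `+=` rather than `=` in backward(): a node can be reached along several
// paths when subexpressions are shared, and each path contributes its own term
// to the total derivative. E.g. for z = x + x, the two recursive calls each
// add g to x's gradient, giving dz/dx = 2 as expected. This is also why the
// gradients must be cleared with reset_grad() between backward passes.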
////////////////////////////////////////////////////////////////
// Forward pass
////////////////////////////////////////////////////////////////
double Node::forward() {
    if (ptr->op == OP_CONST) {
        // Do nothing: constants and variables keep their stored value.
    }
    else if (ptr->op == OP_ADD) {
        double sum = 0.0;
        for (auto c : ptr->children) {
            sum += c.forward();
        }
        set_value(sum);
    }
    else if (ptr->op == OP_MUL) {
        double product = 1.0;
        for (auto c : ptr->children) {
            product *= c.forward();
        }
        set_value(product);
    }
    else if (ptr->op == OP_DIV) {
        assert(ptr->children.size() == 2);
        double x = ptr->children[0].forward();
        double y = ptr->children[1].forward();
        set_value(x / y);
    }
    else if (ptr->op == OP_EXP) {
        assert(ptr->children.size() == 1);
        // Recurse with forward(), not value(): otherwise the subtree
        // under exp() is never recomputed and a stale value gets used.
        set_value(std::exp(ptr->children[0].forward()));
    }
    else assert(false);
    return value();
}
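// Re-evaluating a prebuilt graph (a minimal sketch):
//
//   Node x(1.0);
//   Node y = x * x;    // y.value() == 1.0, computed eagerly
//   x.set_value(3.0);
//   y.forward();       // recomputes the graph: y.value() == 9.0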
////////////////////////////////////////////////////////////////
// Operators
////////////////////////////////////////////////////////////////
Node Node::operator*(const Node& that) const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_MUL, ptr->value * that.ptr->value, 0.0, { *this, that } }
        )
    };
}

Node Node::operator+(const Node& that) const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_ADD, ptr->value + that.ptr->value, 0.0, { *this, that } }
        )
    };
}

Node Node::operator/(const Node& that) const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_DIV, ptr->value / that.ptr->value, 0.0, { *this, that } }
        )
    };
}

Node Node::exp() const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_EXP, std::exp(ptr->value), 0.0, { *this } }
        )
    };
}
Node operator+(double x, const Node& n) {
    return Node(x) + n;
}
Node operator+(const Node& n, double x) {
    return Node(x) + n;
}
Node operator*(double x, const Node& n) {
    return Node(x) * n;
}
Node operator*(const Node& n, double x) {
    return Node(x) * n;
}
Node operator/(double x, const Node& n) {
    return Node(x) / n;
}
////////////////////////////////////////////////////////////////
Node sigmoid(Node x) {
    return 1.0 / (1.0 + (-x).exp());
}
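// No hand-written derivative is needed for sigmoid: backward() composes the
// DIV, ADD, MUL and EXP rules through this expression, which works out to
// the familiar closed form
//     d sigmoid(x) / dx = sigmoid(x) * (1 - sigmoid(x)).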
int main() {
    Node w13(0.1);
    Node w14(0.4);
    Node w24(0.6);
    Node w23(0.8);
    Node w35(0.3);
    Node w45(0.9);

    double x1 = 0.35;
    double x2 = 0.9;

    Node i3 = x1 * w13 + x2 * w23;
    Node i4 = x1 * w14 + x2 * w24;
    Node h3 = sigmoid(i3);
    Node h4 = sigmoid(i4);
    Node i5 = h3 * w35 + h4 * w45;
    Node o = sigmoid(i5);

    o.backward(1.0);

    std::cout << "w13 = " << w13.value() << " dL/dw13 = " << w13.grad() << std::endl;
    std::cout << "w14 = " << w14.value() << " dL/dw14 = " << w14.grad() << std::endl;
    std::cout << "w23 = " << w23.value() << " dL/dw23 = " << w23.grad() << std::endl;
    std::cout << "w24 = " << w24.value() << " dL/dw24 = " << w24.grad() << std::endl;
    std::cout << "w35 = " << w35.value() << " dL/dw35 = " << w35.grad() << std::endl;
    std::cout << "w45 = " << w45.value() << " dL/dw45 = " << w45.grad() << std::endl;
    std::cout << "i3 = " << i3.value() << " dL/di3 = " << i3.grad() << std::endl;
    std::cout << "i4 = " << i4.value() << " dL/di4 = " << i4.grad() << std::endl;
    std::cout << "h3 = " << h3.value() << " dL/dh3 = " << h3.grad() << std::endl;
    std::cout << "h4 = " << h4.value() << " dL/dh4 = " << h4.grad() << std::endl;
    std::cout << "i5 = " << i5.value() << " dL/di5 = " << i5.grad() << std::endl;
    std::cout << "o = " << o.value() << " dL/do = " << o.grad() << std::endl;
// for (int i = 0; i < 30; i++) {
// // Some loss function with a global minimum.
// // Check that precomputed expression graph produces the same result.
// l0.forward();
// assert(std::abs(l0.value() - loss.value()) < 0.00001);
// loss.reset_grad();
// // Backpropagation step.
// loss.backward(1.0);
// // One SGD step.
// x.set_value(x.value() - x.grad() * 0.01);
// y.set_value(y.value() - y.grad() * 0.01);
// std::cout << i << " " << loss.value() << " " << x.grad() << " " << y.grad() << std::endl;
// }
}