Skip to content

Instantly share code, notes, and snippets.

@dramforever
Last active December 12, 2018 18:17
Show Gist options
  • Save dramforever/e7ec9b7d424dae15a7818fef89b8a82f to your computer and use it in GitHub Desktop.
Save dramforever/e7ec9b7d424dae15a7818fef89b8a82f to your computer and use it in GitHub Desktop.
RAIINN (RAII Neural Network) (or maybe just the gradient part)
#include <array>
#include <cmath>
#include <iostream>
#include <memory>
// Base class for every node of the expression graph.
//
// A node stores the value computed when it was constructed and a gradient
// accumulator that other nodes add to through `propagate`. The operator
// nodes in this file (Mul, Add, Sigmoid) push their accumulated gradient to
// their operands from their destructors, which is why the destructor is
// virtual.
struct Expr {
    // Sum of all gradient contributions received so far. It must start at
    // zero: `propagate` only ever adds to it, so leaving it uninitialized
    // would make every accumulated gradient read an indeterminate value.
    double grad = 0.0;

    // Value of this node, fixed at construction.
    double value;

    // Adds `delt` to the accumulated gradient of this node.
    void propagate(double delt) {
        grad += delt;
    }

    // Creates a node holding `value_` with a zero gradient.
    Expr(double value_) : value(value_) {}

    virtual ~Expr() = default;
};
// Shared-ownership handle to a graph node. Operator nodes keep their operands
// alive through these handles.
using ExprP = std::shared_ptr<Expr>;
struct Var : public Expr {
~Var() {}
Var(double value_): Expr(value_) {}
};
// Creates a leaf node holding `x` and returns it as a generic node handle.
ExprP var(double x) {
    // shared_ptr<Var> converts implicitly to shared_ptr<Expr>.
    return std::make_shared<Var>(x);
}
struct Mul : public Expr {
ExprP lhs, rhs;
Mul(ExprP lhs_, ExprP rhs_):
lhs(lhs_), rhs(rhs_), Expr(lhs_->value * rhs_->value) {}
~Mul() {
lhs->propagate(grad * rhs->value);
rhs->propagate(grad * lhs->value);
}
};
// Creates a product node over `lhs` and `rhs` and returns it as a generic
// node handle.
ExprP mul(ExprP lhs, ExprP rhs) {
    // shared_ptr<Mul> converts implicitly to shared_ptr<Expr>.
    return std::make_shared<Mul>(lhs, rhs);
}
struct Add : public Expr {
ExprP lhs, rhs;
Add(ExprP lhs_, ExprP rhs_):
lhs(lhs_), rhs(rhs_), Expr(lhs_->value + rhs_->value) {}
~Add() {
lhs->propagate(grad);
rhs->propagate(grad);
}
};
// Creates a sum node over `lhs` and `rhs` and returns it as a generic node
// handle.
ExprP add(ExprP lhs, ExprP rhs) {
    // shared_ptr<Add> converts implicitly to shared_ptr<Expr>.
    return std::make_shared<Add>(lhs, rhs);
}
struct Sigmoid : public Expr {
ExprP arg;
Sigmoid(ExprP arg_):
arg(arg_), Expr(1 / (1 + exp(- arg_->value))) {}
~Sigmoid() {
arg->propagate(grad * value * (1 - value));
}
};
// Creates a sigmoid node over `arg` and returns it as a generic node handle.
ExprP sigmoid(ExprP arg) {
    // shared_ptr<Sigmoid> converts implicitly to shared_ptr<Expr>.
    return std::make_shared<Sigmoid>(arg);
}
// Builds the expression graph for one sample and returns its root, the
// squared error (output - actual)^2.
//
// Parameters:
//   pr     - the ten parameter nodes:
//            pr[0], pr[1]  weight and bias of the first-layer unit fed by x
//            pr[2], pr[3]  weight and bias of the first-layer unit fed by y
//            pr[4], pr[5]  weights of the first second-layer unit
//            pr[6], pr[7]  weights of the second second-layer unit
//            pr[8], pr[9]  weights of the output unit
//   x, y   - the two inputs
//   actual - the target value the output is compared against
//
// No gradient is computed here: the operator nodes propagate from their
// destructors, so that happens only once the returned graph is destroyed.
ExprP compute(const std::array<ExprP, 10>& pr, double x, double y, double actual) {
    // sigmoid(weight * input + bias)
    const auto biased_unit = [](const ExprP& weight, const ExprP& input, const ExprP& bias) {
        return sigmoid(add(mul(weight, input), bias));
    };
    // sigmoid(wa * a + wb * b)
    const auto pair_unit = [](const ExprP& wa, const ExprP& a,
                              const ExprP& wb, const ExprP& b) {
        return sigmoid(add(mul(wa, a), mul(wb, b)));
    };

    const ExprP first0 = biased_unit(pr[0], var(x), pr[1]);
    const ExprP first1 = biased_unit(pr[2], var(y), pr[3]);

    const ExprP second0 = pair_unit(pr[4], first0, pr[5], first1);
    const ExprP second1 = pair_unit(pr[6], first0, pr[7], first1);

    const ExprP prediction = pair_unit(pr[8], second0, pr[9], second1);

    // The same residual node is used as both operands of the final product.
    const ExprP residual = add(prediction, var(-actual));
    return mul(residual, residual);
}
int main() {
std::array<ExprP, 10> pr {
var(0.3), var(0.4), var(0.5), var(0.1),
var(0.2), var(-0.1), var(-0.3), var(0.4),
var(1), var(-1)
};
compute(pr, 3, -1, 1)->propagate(1);
std::cout << "[";
for (const auto &e : pr) {
std::cout << " " << e->grad;
}
std::cout << " ]\n";
// [ -0.0153562 -0.00511874 -0.00731039 0.00731039 -0.0477443 -0.0243822 0.0478403 0.0244312 -0.129056 -0.117345 ]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment