Last active
April 17, 2022 22:36
-
-
Save Puriney/072a37ea8a181f0b6168 to your computer and use it in GitHub Desktop.
RNN to learn binary addition implemented in Rcpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// Yun Yan | |
// | |
// [[Rcpp::plugins(cpp11)]] | |
#include <bitset> | |
#include <unordered_set> | |
#include <RcppArmadillo.h> | |
// [[Rcpp::depends(RcppArmadillo)]] | |
// #include <RcppEigen.h> | |
//// [[Rcpp::depends(RcppEigen)]] | |
using namespace Rcpp; | |
// using namespace Eigen; | |
using namespace arma; | |
// Rcpp re-implementation of the "RNN in Python" binary-addition example.
//' Logistic sigmoid
//'
//' Applies 1 / (1 + exp(-x)) to every element of the input.
//'
//' @param x A numeric vector.
//' @return A numeric vector of the same length as \code{x}.
//' @export
// [[Rcpp::export]]
NumericVector sigmoid(NumericVector x) {
  // Evaluated element-wise through Rcpp sugar.
  return 1 / (1 + exp(-1 * x));
}
//' Derivative of the sigmoid, expressed through its output
//'
//' Computes x * (1 - x) element-wise. This equals the sigmoid's first
//' derivative only when \code{x} already holds sigmoid outputs.
//'
//' @param x A numeric vector (sigmoid activations).
//' @return A numeric vector of the same length as \code{x}.
//' @export
// [[Rcpp::export]]
NumericVector sigmoid_deriv(NumericVector x) {
  NumericVector slope = x * (1 - x);
  return slope;
}
// Matrix product a %*% b, computed with Armadillo.
// Both arguments are copied into arma::mat before multiplying.
// [[Rcpp::export]]
NumericMatrix dotprodmm(NumericMatrix a, NumericMatrix b) {
  const mat lhs = as<mat>(a);
  const mat rhs = as<mat>(b);
  const mat product = lhs * rhs;
  return wrap(product);
}
// Vector-matrix product a %*% b, where a is treated as a 1-by-n row.
// The result is returned as a matrix with a single row.
// [[Rcpp::export]]
NumericMatrix dotprodvm(NumericVector a, NumericMatrix b) {
  const vec column = as<vec>(a);
  // Transposing the n-by-1 column yields the 1-by-n row operand.
  const mat lhs = column.t();
  const mat rhs = as<mat>(b);
  const mat product = lhs * rhs;
  return wrap(product);
}
// Flatten a matrix into an n-by-1 column (column-major order via
// arma::vectorise) and hand it back to R.
// [[Rcpp::export]]
NumericVector m2v(NumericMatrix x) {
  const mat source = as<mat>(x);
  const vec flattened = vectorise(source);
  return wrap(flattened);
}
// Copy the elements of v, in order, into the storage of a new
// nrow-by-ncol matrix (filled column by column, as R matrices are).
//
// If v is shorter than nrow * ncol the remaining cells keep the
// zero value the matrix is constructed with. If v is longer, an R error
// is raised instead of writing past the end of the matrix storage.
// [[Rcpp::export]]
NumericMatrix v2m(NumericVector v, int nrow, int ncol) {
  NumericMatrix out(nrow, ncol);
  const R_xlen_t capacity = out.size();
  const R_xlen_t len = v.size();
  if (len > capacity) {
    stop("v2m: length of 'v' exceeds nrow * ncol");
  }
  for (R_xlen_t i = 0; i < len; ++i) {
    out[i] = v[i];
  }
  return out;
}
// Build an nrow-by-ncol matrix of starting values.
//   fill == true  : every cell is the constant 0.1
//   fill == false : cells are drawn with runif(nrow * ncol)
// [[Rcpp::export]]
NumericMatrix InitialMatrix(int nrow, int ncol, bool fill = false) {
  const int total = nrow * ncol;
  NumericMatrix result(nrow, ncol);

  NumericVector values;
  if (fill) {
    values = rep(NumericVector::create(0.1), total);
  } else {
    values = runif(total);
  }

  // Copy the generated values into the matrix storage in order.
  for (int k = 0; k < total; ++k) {
    result[k] = values[k];
  }
  return result;
}
// Transpose of a numeric matrix, computed with Armadillo.
// [[Rcpp::export]]
NumericMatrix dotrans(NumericMatrix x) {
  const mat source = as<mat>(x);
  const mat flipped = source.t();
  return wrap(flipped);
}
// Train a small recurrent network to add two integers bit by bit.
//
// Each pair (A[k], B[k]) is one training sample. The operands and their sum
// are converted to 8-bit bitsets (std::bitset keeps only the low 8 bits, so
// sums above 255 are truncated). For every sample the network
//   1. runs a forward pass over the 8 bit positions, starting at bit 0,
//   2. back-propagates through the time-steps in reverse order, and
//   3. applies the accumulated weight updates scaled by alpha.
//
// Layer sizes are fixed: 2 inputs, 16 hidden units, 1 output; alpha = 0.1.
// Weights start as 2 * runif - 1, so they depend on R's RNG state.
//
// @param A       first operands, one per sample
// @param B       second operands; must have the same length as A
// @param verbose if true, print detailed per-step traces; if false, print a
//                progress summary every 1000 samples
// @return list with the trained weight matrices "syn0" (input -> hidden),
//         "syn1" (hidden -> output) and "synh" (hidden -> hidden)
// [[Rcpp::export]]
List RNN_Train(NumericVector A, NumericVector B, bool verbose) {
  // B is indexed with A's sample index below, so the lengths must agree.
  if (A.size() != B.size()) {
    stop("RNN_Train: A and B must have the same length");
  }
  // network basic parameters
  double alpha = 0.1;
  int input_dim = 2;
  int hidden_dim = 16;
  int output_dim = 1;
  int const bin_dim = 8;
  // network weights; [final output]
  NumericMatrix syn0 = 2 * InitialMatrix(input_dim, hidden_dim) - 1;
  NumericMatrix syn1 = 2 * InitialMatrix(hidden_dim, output_dim) - 1;
  NumericMatrix synh = 2 * InitialMatrix(hidden_dim, hidden_dim) - 1;
  // syn0 = syn0 * 0 + 0.1;
  // syn1 = syn1 * 0 + 0.1;
  // synh = synh * 0 + 0.1;
  // accumulators for the weight updates, started at zero
  NumericMatrix syn0_up = InitialMatrix(input_dim, hidden_dim) * 0;
  NumericMatrix syn1_up = InitialMatrix(hidden_dim, output_dim) * 0;
  NumericMatrix synh_up = InitialMatrix(hidden_dim, hidden_dim) * 0;
  // correct answer
  NumericVector C = A + B;
  int N = A.size();
  for (int smp_i = 0; smp_i < N; smp_i ++) { // iterate each sample
    if (verbose) {
      Rcout << "## Sample: " << smp_i << std::endl;
      Rcout << "syn0(0,0) = " << syn0(0, 0) << std::endl;
    }
    int aint = A[smp_i];
    int bint = B[smp_i];
    int cint = C[smp_i];
    std::bitset<bin_dim> a(aint);
    std::bitset<bin_dim> b(bint);
    std::bitset<bin_dim> c(cint);
    NumericVector cHat(bin_dim); // RNN learning binary digit
    double err_sum = 0.0;
    List l2_deltas;
    List l1_vals;
    // hidden state before the first time-step is all zeros
    l1_vals.push_back(rep(NumericVector::create(0), hidden_dim));
    // FP begins
    if (verbose ) Rcout << "-- FP" << std::endl;
    for (std::size_t i = 0; i < a.size(); ++i) { // iterate time-steps
      // input and output of each time-step
      NumericVector x = NumericVector::create(a[i], // Notice bit-set operator
                                              b[i]); // right-most
      NumericVector y = NumericVector::create(c[i]);
      // hidden_layer ~ input + prev_hidden
      NumericVector l1 = sigmoid(dotprodvm(x, syn0) +
        dotprodvm(as<NumericVector>(l1_vals[l1_vals.size() - 1]), synh));
      // output layer
      NumericVector l2 = sigmoid(dotprodvm(l1, syn1));
      // error at output layer
      NumericVector l2_err = y - l2;
      err_sum += sum(abs(l2_err));
      l2_deltas.push_back(l2_err * sigmoid_deriv(l2));
      if (verbose) {
        Rcout << ">> sample " << smp_i << "pos" << i << "=" << x << std::endl;
        Rcout << a[bin_dim - i - 1] << std::endl;
        Rcout << l2_err << std::endl;
      }
      // save output to be displayed (stored most-significant bit first)
      cHat[bin_dim -i -1] = round(l2[0]);
      // save hidden layer to be used for next time-step
      l1_vals.push_back(clone(l1));
    }
    // FP ends
    if (verbose) {
      for (List::iterator li = l1_vals.begin(); li != l1_vals.end(); li ++){
        Rcout << "l1 vals: " << as<NumericVector>(*li) << std::endl;
      }
    }
    // BP begins
    if (verbose) Rcout << "-- BP" << std::endl;
    // layer-1 at "next-time"-step
    NumericVector future_l1_delta = rep(NumericVector::create(0), hidden_dim);
    for (std::size_t i = 0; i < a.size(); ++i) {
      // walk the time-steps in reverse: step i here is bit bin_dim - i - 1
      NumericVector x = NumericVector::create(a[bin_dim - i - 1],
                                              b[bin_dim - i - 1]);
      NumericVector l1 = l1_vals[l1_vals.size() - i - 1];
      NumericVector l1_prev = l1_vals[l1_vals.size() - i -2];
      // delta at output layer
      NumericVector l2_delta = l2_deltas[l2_deltas.size() -i - 1];
      // delta at hidden layer
      NumericVector l1_delta = (dotprodvm(l2_delta, dotrans(syn1)) +
        dotprodvm(future_l1_delta, dotrans(synh))) *
        sigmoid_deriv(l1);
      // collect updates until all time-steps finished
      syn1_up += dotprodmm(v2m(l1, l1.size(), 1),
                           v2m(l2_delta, 1, l2_delta.size()));
      synh_up += dotprodmm(v2m(l1_prev, l1_prev.size(), 1),
                           v2m(l1_delta, 1, l1_delta.size()));
      syn0_up += dotprodmm(v2m(x, x.size(), 1),
                           v2m(l1_delta, 1, l1_delta.size()));
      future_l1_delta = l1_delta;
      if (verbose){
        Rcout << "bp input at pos: " << i << "=" << x << std::endl;
        Rcout << "future delta: " << future_l1_delta << std::endl;
      }
    }
    // BP ends here
    // Update network parameters
    syn0 += (syn0_up * alpha);
    syn1 += (syn1_up * alpha);
    synh += (synh_up * alpha);
    // reset the accumulators for the next sample
    syn0_up = InitialMatrix(input_dim, hidden_dim) * 0;
    syn1_up = InitialMatrix(hidden_dim, output_dim) * 0;
    synh_up = InitialMatrix(hidden_dim, hidden_dim) * 0;
    if (verbose) {
      Rcout << "!!After" << std::endl;
      Rcout << syn0(0, 0) << syn0(0, 1) << syn0(1, 0) << syn0(1, 1) << std::endl;
      // syn1 is hidden_dim x output_dim (16 x 1): only column 0 exists.
      Rcout << syn1(0, 0) << syn1(1, 0) << syn1(2, 0) << syn1(3, 0) << std::endl;
      Rcout << synh(0, 0) << synh(0, 1) << synh(1, 0) << synh(1, 1) << std::endl;
      Rcout << syn0_up(1, 2) << std::endl;
    }
    if (!verbose && smp_i % 1000 == 0 ){
      Rcout << "Sample " << smp_i << std::endl;
      Rcout << "Overall Error: " << err_sum << std::endl;
      Rcout << "Pred: [" << cHat << "]" << std::endl;
      Rcout << "True: " << c.to_string() << std::endl;
      // convert the predicted bits (most-significant first) to a number
      double cHatInt = 0.0;
      for (int i = 0; i < cHat.size(); i ++){
        cHatInt += pow(2.0, cHat.size() - i - 1) * cHat[i];
      }
      Rcout << "Calc Binary: " << std::endl << a << std::endl << b << std::endl;
      Rcout << "Calc Decimal: " << aint << "+" << bint << "=" << cHatInt << std::endl;
      Rcout << "----" << std::endl;
    }
  }
  List syn = List::create(_["syn0"] = syn0,
                          _["syn1"] = syn1,
                          _["synh"] = synh);
  return syn;
}
/*** R
# Demo run: train the network on random pairs of small integers.
require(Rcpp)
require(RcppArmadillo)
require(RcppEigen)

set.seed(2016)
max_num <- 2^8 - 1
n <- 10000
# Operands are drawn from 1:(max_num / 2); a final 1 + 1 pair is appended.
A <- c(sample(1:(max_num / 2), n, replace = TRUE), 1)
B <- c(sample(1:(max_num / 2), n, replace = TRUE), 1)
vFlag <- FALSE
rnn_fit <- RNN_Train(A, B, vFlag)
*/
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://q-aps.princeton.edu/sites/default/files/q-aps/files/slides_day4_am.pdf