Skip to content

Instantly share code, notes, and snippets.

@Puriney
Last active April 17, 2022 22:36
Show Gist options
  • Save Puriney/072a37ea8a181f0b6168 to your computer and use it in GitHub Desktop.
RNN to learn binary addition implemented in Rcpp
//
// Yun Yan
//
// [[Rcpp::plugins(cpp11)]]
#include <algorithm>
#include <bitset>
#include <cmath>
#include <unordered_set>
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// #include <RcppEigen.h>
//// [[Rcpp::depends(RcppEigen)]]
using namespace Rcpp;
// using namespace Eigen;
using namespace arma;
// Using Rcpp to reproduce "RNN in python"
//' Sigmoid function
//'
//' @param x A numeric vector.
//' @export
// [[Rcpp::export]]
NumericVector sigmoid(NumericVector x) {
  // Element-wise logistic function: 1 / (1 + exp(-x)), via Rcpp sugar.
  NumericVector out = 1 / (1 + exp(-1 * x));
  return out;
}
//' First derivative of sigmoid function
//'
//' @param x A numeric vector.
//' @export
// [[Rcpp::export]]
NumericVector sigmoid_deriv(NumericVector x) {
  // Derivative of the logistic expressed in terms of its OUTPUT:
  // if x = sigmoid(z) then d(sigmoid)/dz = x * (1 - x).
  // NOTE(review): callers must pass sigmoid activations, not raw inputs.
  NumericVector deriv = x * (1 - x);
  return deriv;
}
// [[Rcpp::export]]
NumericMatrix dotprodmm(NumericMatrix a, NumericMatrix b) {
  // Matrix product c = a %*% b, delegated to Armadillo.
  mat lhs = as<mat>(a);
  mat rhs = as<mat>(b);
  mat prod = lhs * rhs;
  return wrap(prod);
}
// [[Rcpp::export]]
NumericMatrix dotprodvm(NumericVector a, NumericMatrix b) {
  // Row-vector times matrix: treats `a` as a 1-by-n matrix and returns
  // the 1-by-ncol(b) product a %*% b.
  //
  // Previously this built the 1-by-n operand with insert_cols() followed
  // by reshape(); converting directly to an Armadillo rowvec expresses
  // the same thing in one step.
  rowvec aa = as<rowvec>(a);
  mat bb = as<mat>(b);
  mat cc = aa * bb;
  return wrap(cc);
}
// [[Rcpp::export]]
NumericVector m2v(NumericMatrix x){
  // Flatten a matrix into a plain vector, column-by-column
  // (Armadillo's and R's shared column-major order).
  mat xmat = as<mat>(x);
  vec flat = vectorise(xmat);
  return wrap(flat);
}
// [[Rcpp::export]]
NumericMatrix v2m(NumericVector v, int nrow, int ncol){
  // Reshape `v` into an nrow-by-ncol matrix, filling column-major.
  // Fix: the old loop ran over all of v, writing past the matrix
  // storage whenever v.size() > nrow * ncol; copy at most nrow*ncol
  // elements (extra cells, if any, stay zero-initialized).
  NumericMatrix out(nrow, ncol);
  int n = std::min<int>(v.size(), nrow * ncol);
  for (int i = 0; i < n; i++) {
    out[i] = v[i];
  }
  return out;
}
// [[Rcpp::export]]
NumericMatrix InitialMatrix(int nrow, int ncol, bool fill = false){
  // Build an nrow-by-ncol matrix. fill == true -> every cell is 0.1
  // (deterministic debugging weights); fill == false -> i.i.d. runif(0,1)
  // draws from R's RNG (so set.seed() in R controls them).
  //
  // Fixes: the old version allocated a throwaway rep() vector for the
  // constant case, copied with a hand-rolled dual-iterator loop, and
  // returned clone(ret) even though ret is already a fresh local object.
  int n = nrow * ncol;
  NumericMatrix ret(nrow, ncol);
  if (fill) {
    std::fill(ret.begin(), ret.end(), 0.1);
  } else {
    NumericVector val = runif(n);
    std::copy(val.begin(), val.end(), ret.begin());
  }
  return ret;
}
// [[Rcpp::export]]
NumericMatrix dotrans(NumericMatrix x){
  // Matrix transpose, delegated to Armadillo.
  mat xt = as<mat>(x).t();
  return wrap(xt);
}
// [[Rcpp::export]]
List RNN_Train(NumericVector A, NumericVector B, bool verbose){
  // Train a small recurrent network to perform 8-bit binary addition,
  // one (A[i], B[i]) pair per SGD step (port of the classic
  // "anyone can learn an RNN" Python demo).
  //
  // @param A,B     operand vectors; each A[i] + B[i] must fit in 8 bits.
  // @param verbose dump per-timestep internals (very noisy).
  // @return        List with trained weights: syn0 (input->hidden),
  //                syn1 (hidden->output), synh (hidden->hidden).
  //
  // Fix vs. original: the verbose dump read syn1(0, 3) — but syn1 is
  // hidden_dim x 1 (16 x 1), so column 3 is out of bounds. It now reads
  // syn1(3, 0). All training logic is otherwise unchanged.

  // network basic parameters
  double alpha = 0.1;          // learning rate
  int input_dim = 2;           // two input bits per timestep (one from A, one from B)
  int hidden_dim = 16;
  int output_dim = 1;          // one predicted sum bit per timestep
  int const bin_dim = 8;       // bit width; must be compile-time for std::bitset

  // network weights in [-1, 1)
  NumericMatrix syn0 = 2 * InitialMatrix(input_dim, hidden_dim) - 1;
  NumericMatrix syn1 = 2 * InitialMatrix(hidden_dim, output_dim) - 1;
  NumericMatrix synh = 2 * InitialMatrix(hidden_dim, hidden_dim) - 1;
  // accumulated gradient updates, zeroed between samples
  NumericMatrix syn0_up = InitialMatrix(input_dim, hidden_dim) * 0;
  NumericMatrix syn1_up = InitialMatrix(hidden_dim, output_dim) * 0;
  NumericMatrix synh_up = InitialMatrix(hidden_dim, hidden_dim) * 0;

  // correct answers
  NumericVector C = A + B;
  int N = A.size();
  for (int smp_i = 0; smp_i < N; smp_i ++) { // iterate each sample
    if (verbose) {
      Rcout << "## Sample: " << smp_i << std::endl;
      Rcout << "syn0(0,0) = " << syn0(0, 0) << std::endl;
    }
    int aint = A[smp_i];
    int bint = B[smp_i];
    int cint = C[smp_i];
    std::bitset<bin_dim> a(aint);
    std::bitset<bin_dim> b(bint);
    std::bitset<bin_dim> c(cint);
    NumericVector cHat(bin_dim); // the network's predicted binary digits
    double err_sum = 0.0;
    List l2_deltas;              // output-layer deltas, one per timestep
    List l1_vals;                // hidden states, one per timestep
    // initial hidden state is all zeros
    l1_vals.push_back(rep(NumericVector::create(0), hidden_dim));

    // FP begins: feed bits LSB -> MSB (bitset operator[] indexes from LSB)
    if (verbose ) Rcout << "-- FP" << std::endl;
    for (std::size_t i = 0; i < a.size(); ++i) { // iterate time-steps
      // input and output of each time-step
      NumericVector x = NumericVector::create(a[i], // Notice bit-set operator
                                              b[i]); // right-most
      NumericVector y = NumericVector::create(c[i]);
      // hidden_layer ~ input + prev_hidden
      NumericVector l1 = sigmoid(dotprodvm(x, syn0) +
        dotprodvm(as<NumericVector>(l1_vals[l1_vals.size() - 1]), synh));
      // output layer
      NumericVector l2 = sigmoid(dotprodvm(l1, syn1));
      // error at output layer
      NumericVector l2_err = y - l2;
      err_sum += sum(abs(l2_err));
      l2_deltas.push_back(l2_err * sigmoid_deriv(l2));
      if (verbose) {
        Rcout << ">> sample " << smp_i << "pos" << i << "=" << x << std::endl;
        Rcout << a[bin_dim - i - 1] << std::endl;
        Rcout << l2_err << std::endl;
      }
      // save output to be displayed (MSB printed left, so reverse the index)
      cHat[bin_dim -i -1] = std::round(l2[0]);
      // save hidden layer to be used for next time-step
      l1_vals.push_back(clone(l1));
    }
    // FP ends
    if (verbose) {
      for (List::iterator li = l1_vals.begin(); li != l1_vals.end(); li ++){
        Rcout << "l1 vals: " << as<NumericVector>(*li) << std::endl;
      }
    }

    // BP begins: walk timesteps in reverse (last fed bit first)
    if (verbose) Rcout << "-- BP" << std::endl;
    // hidden-layer delta from the "next" (future) time-step; zero at the end
    NumericVector future_l1_delta = rep(NumericVector::create(0), hidden_dim);
    for (std::size_t i = 0; i < a.size(); ++i) {
      NumericVector x = NumericVector::create(a[bin_dim - i - 1],
                                              b[bin_dim - i - 1]);
      NumericVector l1 = l1_vals[l1_vals.size() - i - 1];
      NumericVector l1_prev = l1_vals[l1_vals.size() - i -2];
      // delta at output layer
      NumericVector l2_delta = l2_deltas[l2_deltas.size() -i - 1];
      // delta at hidden layer: error from the output plus error
      // propagated back from the future hidden state
      NumericVector l1_delta = (dotprodvm(l2_delta, dotrans(syn1)) +
        dotprodvm(future_l1_delta, dotrans(synh))) *
        sigmoid_deriv(l1);
      // collect updates until all time-steps finished
      syn1_up += dotprodmm(v2m(l1, l1.size(), 1),
                           v2m(l2_delta, 1, l2_delta.size()));
      synh_up += dotprodmm(v2m(l1_prev, l1_prev.size(), 1),
                           v2m(l1_delta, 1, l1_delta.size()));
      syn0_up += dotprodmm(v2m(x, x.size(), 1),
                           v2m(l1_delta, 1, l1_delta.size()));
      future_l1_delta = l1_delta;
      if (verbose){
        Rcout << "bp input at pos: " << i << "=" << x << std::endl;
        Rcout << "future delta: " << future_l1_delta << std::endl;
      }
    }
    // BP ends here

    // Update network parameters and zero the accumulators
    syn0 += (syn0_up * alpha);
    syn1 += (syn1_up * alpha);
    synh += (synh_up * alpha);
    syn0_up = InitialMatrix(input_dim, hidden_dim) * 0;
    syn1_up = InitialMatrix(hidden_dim, output_dim) * 0;
    synh_up = InitialMatrix(hidden_dim, hidden_dim) * 0;
    if (verbose) {
      Rcout << "!!After" << std::endl;
      Rcout << syn0(0, 0) << syn0(0, 1) << syn0(1, 0) << syn0(1, 1) << std::endl;
      // BUGFIX: syn1 is hidden_dim x 1; the original read syn1(0, 3),
      // an out-of-bounds column. Print the 4th row of the single column.
      Rcout << syn1(0, 0) << syn1(1, 0) << syn1(2, 0) << syn1(3, 0) << std::endl;
      Rcout << synh(0, 0) << synh(0, 1) << synh(1, 0) << synh(1, 1) << std::endl;
      Rcout << syn0_up(1, 2) << std::endl;
    }
    // progress report every 1000 samples
    if (!verbose && smp_i % 1000 == 0 ){
      Rcout << "Sample " << smp_i << std::endl;
      Rcout << "Overall Error: " << err_sum << std::endl;
      Rcout << "Pred: [" << cHat << "]" << std::endl;
      Rcout << "True: " << c.to_string() << std::endl;
      double cHatInt = 0.0;
      for (int i = 0; i < cHat.size(); i ++){
        cHatInt += pow(2.0, cHat.size() - i - 1) * cHat[i];
      }
      Rcout << "Calc Binary: " << std::endl << a << std::endl << b << std::endl;
      Rcout << "Calc Decimal: " << aint << "+" << bint << "=" << cHatInt << std::endl;
      Rcout << "----" << std::endl;
    }
  }
  List syn = List::create(_["syn0"] = syn0,
                          _["syn1"] = syn1,
                          _["synh"] = synh);
  return syn;
}
/*** R
# Demo: train the RNN on 10,000 random 8-bit addition problems.
require(Rcpp)
require(RcppArmadillo)
# NOTE: RcppEigen is not loaded — the Eigen include in the C++ above is
# commented out, so requiring it only added a needless dependency.
set.seed(2016)
max_num <- 2^8 - 1
n <- 10000
# Cap operands at max_num/2 so every sum fits in 8 bits.
A <- c(sample(1:(max_num/2), n, replace = TRUE), 1)
B <- c(sample(1:(max_num/2), n, replace = TRUE), 1)
vFlag <- FALSE
rnn_fit <- RNN_Train(A, B, vFlag)
*/
@Puriney
Copy link
Author

Puriney commented Feb 23, 2016

Sample 0
Overall Error: 4.44144
Pred: [1 1 0 0 0 0 1 0]
True: 00110001
Calc Binary: 
00010111
00011010
Calc Decimal: 23+26=194
----
Sample 1000
Overall Error: 3.95198
Pred: [1 0 0 0 0 0 0 0]
True: 10101100
Calc Binary: 
01101000
01000100
Calc Decimal: 104+68=128
----
Sample 2000
Overall Error: 4.22964
Pred: [0 1 1 1 1 0 1 0]
True: 10000100
Calc Binary: 
01111101
00000111
Calc Decimal: 125+7=122
----
Sample 3000
Overall Error: 3.38044
Pred: [0 1 1 1 0 0 0 0]
True: 10000000
Calc Binary: 
01001000
00111000
Calc Decimal: 72+56=112
----
Sample 4000
Overall Error: 2.44474
Pred: [0 0 1 1 1 1 0 0]
True: 00110100
Calc Binary: 
00000110
00101110
Calc Decimal: 6+46=60
----
Sample 5000
Overall Error: 1.72278
Pred: [0 1 1 0 0 0 0 1]
True: 01100001
Calc Binary: 
01010010
00001111
Calc Decimal: 82+15=97
----
Sample 6000
Overall Error: 0.613623
Pred: [0 0 1 1 1 1 0 1]
True: 00111101
Calc Binary: 
00100001
00011100
Calc Decimal: 33+28=61
----
Sample 7000
Overall Error: 0.743125
Pred: [1 0 1 1 1 1 0 0]
True: 10111100
Calc Binary: 
01111101
00111111
Calc Decimal: 125+63=188
----
Sample 8000
Overall Error: 0.488955
Pred: [1 0 0 0 0 1 0 0]
True: 10000100
Calc Binary: 
01011011
00101001
Calc Decimal: 91+41=132
----
Sample 9000
Overall Error: 0.355511
Pred: [0 0 1 1 0 0 1 0]
True: 00110010
Calc Binary: 
00001100
00100110
Calc Decimal: 12+38=50
----
Sample 10000
Overall Error: 0.125537
Pred: [0 0 0 0 0 0 1 0]
True: 00000010
Calc Binary: 
00000001
00000001
Calc Decimal: 1+1=2
----

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment