Created
July 1, 2019 10:56
-
-
Save marty1885/61eecaf605e258568bd28b0dc1d86a48 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//Released under the BSD 3-Clause License
//How to run: root -b stackedsp.cpp
//or: c++ stackedsp.cpp -o stackedsp -O3 -lEtaler `root-config --cflags --ldflags --glibs` && ./stackedsp
//Assumes Etaler is installed in /usr/local/lib. Change the path below if it is installed elsewhere.
#pragma cling load("/usr/local/lib/libEtaler.so")
#include <cstdint>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include <Etaler/Etaler.hpp>
#include <Etaler/Algorithms/SpatialPooler.hpp>
#include <Etaler/Encoders/GridCell1d.hpp>
#include <Etaler/Encoders/Scalar.hpp>

#include <TGraph.h>
#include <TCanvas.h>
#include <TAxis.h>

using namespace et;

constexpr intmax_t INPUT_SDR_SIZE = 2048;
void run_experiment(int num_sps) | |
{ | |
//Make a stack of N SpatialPoolers | |
std::vector<SpatialPooler> sps; | |
for(int i=0;i<num_sps;i++) { | |
SpatialPooler sp(Shape({INPUT_SDR_SIZE}), Shape({INPUT_SDR_SIZE})); | |
sp.setBoostingFactor(10); | |
sp.setGlobalDensity(0.04); | |
sps.push_back(sp); | |
} | |
//Encode a scalar using a 1D GridCell encoder. A ScalarEncoder does the same job. | |
auto encode = [](float v){return encoder::gridCell1d(v,128);}; | |
//Like an RBM, inference on each layer, train it and use it's outout as the | |
//next layer's input | |
auto train = [&](auto x){for(auto& sp : sps) {auto y = sp.compute(x); sp.learn(x, y); x = y;}}; | |
//Train the SP with random inputs | |
std::mt19937 rng; | |
std::uniform_real_distribution<float> dist; | |
for(size_t i=0;i<1000;i++) { | |
Tensor x = encode(dist(rng)); | |
train(x); | |
} | |
auto inference = [&](auto sdr) {Tensor x = sdr; for(auto& sp : sps) x = sp.compute(x); return x;}; | |
//Calculathe the amount of overlapped bits from values 0~1 compared to 0.5 | |
Tensor t = inference(encode(0.5)); | |
std::vector<float> y; | |
std::vector<float> x; | |
for(float i=0;i<1;i+=0.01) { | |
int v = sum(t && inference(encode(i))).toHost<int>()[0]; | |
y.push_back(v); | |
x.push_back(i); | |
} | |
//Now we are done with traning and generating the results. We plot the result | |
//Plot the resulting overlapped bits using ROOT's ploting capablity | |
auto c1 = std::make_unique<TCanvas>(); | |
auto gr1 = std::make_unique<TGraph>((Int_t)x.size(), x.data(), y.data()); | |
//Set the textes | |
gr1->SetTitle(("Overlap cells with " + std::to_string(num_sps) + " SPs").c_str()); | |
gr1->GetXaxis()->SetTitle("Encoded value"); | |
gr1->GetYaxis()->SetTitle("Overlapped bits"); | |
gr1->Draw(); | |
//Save | |
c1->SaveAs(("plot_" + std::to_string(num_sps) + ".png").c_str()); | |
c1->Close(); | |
} | |
void stackedsp() | |
{ | |
//C++17 type reduction is awesome | |
std::vector num_sps = {0, 1, 2, 4, 8, 16}; | |
for(auto n : num_sps) | |
run_experiment(n); | |
} | |
int main() | |
{ | |
stackedsp(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment