Created
February 12, 2018 09:02
-
-
Save marty1885/879e5d18c3c2f3b92004561812f780c6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <Athena/Athena.hpp>
#include <Athena/XtensorBackend.hpp>
#include "csv.h"
using namespace std;
// Returns a one-hot vector: `classes` zeros with a single 1.0f at position `n`.
// `classes` defaults to 3 (the number of Iris species), so existing callers
// are unaffected.
// Throws std::out_of_range if n is not in [0, classes) instead of writing out
// of bounds.
std::vector<float> onehot(int n, std::size_t classes = 3)
{
	if(n < 0 || static_cast<std::size_t>(n) >= classes)
		throw std::out_of_range("onehot: index " + std::to_string(n)
			+ " is outside [0, " + std::to_string(classes) + ")");
	std::vector<float> v(classes, 0.0f);
	v[static_cast<std::size_t>(n)] = 1.0f;
	return v;
}
// Inverse of onehot(): returns the position of the largest element in `vec`
// (an argmax). When several elements tie for the maximum, the earliest one is
// reported. An empty vector yields 0.
int reverseOnehot(const std::vector<float>& vec)
{
	std::vector<float>::size_type best = 0;
	for(std::vector<float>::size_type i = 1; i < vec.size(); ++i)
	{
		// Only a strictly greater value moves the winner, so ties keep the first.
		if(vec[best] < vec[i])
			best = i;
	}
	return static_cast<int>(best);
}
int main() | |
{ | |
//Load the dataset | |
io::CSVReader<5> in("iris.data"); | |
in.read_header(io::ignore_extra_column | |
,"sepal_length","sepal_width","petal_length","petal_width","species"); | |
float sepalLength, sepalWidth, peralLength, petalWidth; | |
std::string species; | |
vector<vector<float>> input; | |
vector<vector<float>> output; | |
while(in.read_row(sepalLength, sepalWidth, peralLength, petalWidth, species)) | |
{ | |
input.push_back({sepalLength, sepalWidth, peralLength, petalWidth}); | |
//Convert species into indices | |
int id = 0; | |
if(species == "Iris-setosa") | |
id = 0; | |
else if(species == "Iris-versicolor") | |
id = 1; | |
else if(species == "Iris-virginica") | |
id = 2; | |
else | |
cerr << "Dataset error: species " << species << "should not exist" << endl; | |
//Then convert the index into one-hot vector | |
output.push_back(onehot(id)); | |
} | |
//Shuffle the dataset. | |
//Note that because we're using the same seed for both shuffle, | |
//the data is shuffled in the same order | |
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); | |
shuffle(input.begin(), input.end(), std::default_random_engine(seed)); | |
shuffle(output.begin(), output.end(), std::default_random_engine(seed)); | |
//Initialize a backend for the network to work upon | |
At::XtensorBackend backend; | |
At::Tensor::setDefaultBackend(&backend); | |
//Create a model | |
At::SequentialNetwork net(&backend); | |
net << At::FullyConnectedLayer(4, 16) | |
<< At::TanhLayer() | |
<< At::FullyConnectedLayer(16, 3); | |
net.compile(); | |
//Convert nested vectors into At::Tensor | |
At::Tensor inputTensor = input; | |
At::Tensor outputTensor = output; | |
//The callbacks. Prints loss each epoch | |
auto onBatch = [&](float l){}; | |
auto onEpoch = [&](float l){std::cout << "Epoch Loss: " << l << "\r";}; | |
//Train the model | |
At::L1Loss loss; | |
At::RMSPropOptimizer optimizer; | |
intmax_t batchSize = 2; | |
intmax_t epochs = 150; | |
net.fit(optimizer, loss, inputTensor, outputTensor, batchSize, epochs | |
, onBatch, onEpoch); | |
cout << "\n"; | |
//Test our classifer | |
int correct = 0; | |
for(int i=0;i<input.size();i++) | |
{ | |
At::Tensor in = input[i]; | |
At::Tensor out = net.predict(in); | |
//cout << out << " " << correctSol << endl; | |
//Tensor::host() copies whaterver is in the tensor | |
// into a std::vector<float> flattened | |
int correctSol = reverseOnehot(output[i]); | |
int prediction = reverseOnehot(out.host()); | |
if(correctSol == prediction) | |
correct++; | |
} | |
cout << "Accuracy: " << (float)correct/input.size()*100.f << "%" << endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment