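// MNIST digit classification with the Athena deep learning library,
// using the xtensor backend (optionally accelerated by NNPACK).
// Each digit label (0-9) is encoded as 4 binary bits rather than a 10-way
// one-hot vector, so the network only needs 4 sigmoid outputs; see
// labelsToBinary and binaryDecode below.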
#include <Athena/Athena.hpp>
#include <Athena/Backend/XtensorBackend.hpp>
#include <Athena/Backend/NNPACKBackend.hpp>

#include "mnist_reader.hpp"

#include <iostream>
#include <chrono>

using namespace std::chrono;
// Returns the index of the largest element in a vector.
// (Helper for decoding one-hot outputs; not used in this binary-encoding example.)
int maxElementIndex(const std::vector<float>& vec)
{
    return std::distance(vec.begin(), std::max_element(vec.begin(), vec.end()));
}
// Flattens the MNIST images into a single (numImages, 784) tensor,
// normalizing pixel values from [0, 255] to [0, 1].
At::Tensor imagesToTensor(const std::vector<std::vector<uint8_t>>& arr)
{
    std::vector<float> res;
    res.reserve(arr.size()*arr[0].size());
    for(const auto& img : arr)
    {
        for(const auto v : img)
            res.push_back(v/255.f);
    }
    return At::Tensor(res, {(intmax_t)arr.size(), (intmax_t)arr[0].size()});
}
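// For the standard 60,000-image MNIST training set this produces a
// 60000 x 784 float tensor with values in [0, 1].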
// Builds a one-hot vector of length `total` with a 1 at position `ind`.
// (Not used below; the labels are encoded in binary instead.)
std::vector<float> onehot(int ind, int total)
{
    std::vector<float> vec(total);
    for(int i=0;i<total;i++)
        vec[i] = (i==ind) ? 1.f : 0.f;
    return vec;
}
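// The digit labels 0-9 fit in 4 bits, so instead of a 10-way one-hot target
// each label is converted into its 4-bit binary representation.
// For example, label 9 (binary 1001) becomes the target vector {1, 0, 0, 1}.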
// Converts each digit label into its 4-bit binary representation,
// most significant bit first, packed into a (numLabels, 4) tensor.
At::Tensor labelsToBinary(const std::vector<uint8_t>& labels)
{
    std::vector<float> buffer;
    buffer.reserve(labels.size()*4);
    for(const auto& label : labels)
    {
        for(int i=0;i<4;i++)
            buffer.push_back((label & (1<<(4-i-1))) != 0 ? 1.f : 0.f);
    }
    return At::Tensor(buffer, {(intmax_t)labels.size(), 4});
}
// Thresholds each of the 4 network outputs at 0.5 and reads the resulting
// bits (most significant first) back into an integer label.
int binaryDecode(const At::Tensor& t)
{
    auto vec = t.host<float>();
    std::string str;
    for(int i=0;i<4;i++)
        str += (vec[i] > 0.5 ? "1" : "0");
    return std::stoi(str, nullptr, 2);
}
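// For example, a prediction of {0.1, 0.9, 0.8, 0.2} thresholds to "0110",
// which decodes to the digit 6.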
int main()
{
    // Use the xtensor backend as the default tensor backend.
    At::XtensorBackend backend;
    At::Tensor::setDefaultBackend(&backend);

    // Use the NNPACK backend to accelerate supported operations.
    // Remove these two lines if NNPACK is not available.
    At::NNPackBackend nnpBackend;
    backend.useAllAlgorithm(nnpBackend);

    At::SequentialNetwork net;
    // Load MNIST (the "../mnist" path should point at the raw MNIST files)
    // and convert images and labels into Athena tensors.
    auto dataset = mnist::read_dataset<std::vector, std::vector, uint8_t, uint8_t>("../mnist");
    At::Tensor trainingImage = imagesToTensor(dataset.training_images);
    At::Tensor trainingLabels = labelsToBinary(dataset.training_labels);
    At::Tensor testingImage = imagesToTensor(dataset.test_images);
    At::Tensor testingLabels = labelsToBinary(dataset.test_labels);
    // A small fully connected network: 784 inputs -> 50 hidden units -> 4 outputs,
    // with a sigmoid activation after each layer.
    net.add(At::FullyConnectedLayer(784, 50));
    net.add(At::SigmoidLayer());
    net.add(At::FullyConnectedLayer(50, 4));
    net.add(At::SigmoidLayer());
    net.compile();
    net.summary({At::Shape::None, 784});
    At::AdamOptimizer opt;
    At::MSELoss loss;

    size_t epoch = 8;
    size_t batchSize = 16;

    // Progress callbacks. "\033[2K\r" clears the current terminal line and returns
    // the cursor, so the per-batch progress is printed in place.
    int count = 0;
    auto onBatch = [&](float l)
    {
        int sampleNum = trainingImage.shape()[0];
        std::cout << "\033[2K\r"
            << count << "/" << sampleNum << ", Loss = " << l << std::flush;
        count += batchSize;
    };
    auto onEpoch = [&](float l)
    {
        std::cout << "\033[2K\r"
            << "Epoch Loss: " << l << std::endl;
        count = 0;
    };
    // Train the network and time the whole run.
    high_resolution_clock::time_point t1 = high_resolution_clock::now();
    net.fit(opt, loss, trainingImage, trainingLabels, batchSize, epoch, onBatch, onEpoch);
    high_resolution_clock::time_point t2 = high_resolution_clock::now();
    duration<double> time_span = duration_cast<duration<double>>(t2 - t1);
    std::cout << "Training took " << time_span.count() << " seconds." << std::endl;
    // Evaluate: predict each test image one at a time, decode the 4-bit output
    // back into a digit and compare it against the true label.
    int correct = 0;
    for(intmax_t i=0;i<testingImage.shape()[0];i++)
    {
        At::Tensor x = testingImage.chunk({i}, {1});
        At::Tensor res = net.predict(x);
        int predictLabel = binaryDecode(res);
        if(predictLabel == dataset.test_labels[i])
            correct++;
    }
    std::cout << "Accuracy: " << correct/(float)testingImage.shape()[0] << std::endl;
}