Last active
May 30, 2018 06:26
-
-
Save marty1885/4d9a28e34b5ac1fea998526314511bc6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//tiny-dnn headers
#define CNN_USE_AVX
#include <tiny_dnn/tiny_dnn.h>
using namespace tiny_dnn;
using namespace tiny_dnn::activation;
using namespace tiny_dnn::layers;
//tiny-dnn comes with xtensor, use it
#include <tiny_dnn/xtensor/xnpy.hpp>
#include <tiny_dnn/xtensor/xview.hpp>
#include <tiny_dnn/xtensor/xsort.hpp>
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <random>
#include <tuple>
#include <vector>
// Training hyper-parameters.
// NOTE(review): despite its name, NUM_BATCHES is passed to net.train() as the
// number of *epochs*; BATCH_SIZE is the mini-batch size.
constexpr int NUM_BATCHES = 50;
constexpr int BATCH_SIZE  = 32;
//Convert data to a form tiny-dnn can read | |
// Convert an xtensor dataset into the form tiny-dnn can read:
// each row of X becomes a vec_t sample, each one-hot row of Y becomes a
// label_t (index of the max element). The pairs are then shuffled
// deterministically (seed 42) while keeping samples and labels aligned.
// Returns std::tuple<std::vector<vec_t>, std::vector<label_t>>.
decltype(auto) convertDataset(const xt::xarray<float>& X, const xt::xarray<float>& Y)
{
	const size_t numData = X.shape()[0];
	std::vector<vec_t> x(numData);
	std::vector<label_t> y(numData);
	for(size_t i = 0; i < numData; i++)
	{
		auto xb = xt::view(X, i, xt::all());
		x[i] = vec_t(xb.begin(), xb.end());
		auto yb = xt::view(Y, i, xt::all());
		y[i] = xt::argmax(yb)[0]; // one-hot row -> class index
	}
	// Shuffle with ONE permutation applied to both vectors.
	// The original shuffled x and y separately with two equally-seeded
	// mt19937s; std::shuffle's RNG consumption is unspecified by the
	// standard, so sample/label correspondence was only preserved by
	// implementation accident. An explicit index permutation is guaranteed
	// to keep the pairs aligned.
	std::vector<size_t> perm(numData);
	std::iota(perm.begin(), perm.end(), size_t{0});
	std::shuffle(perm.begin(), perm.end(), std::mt19937(42));
	std::vector<vec_t> xs(numData);
	std::vector<label_t> ys(numData);
	for(size_t i = 0; i < numData; i++)
	{
		xs[i] = std::move(x[perm[i]]); // vec_t may be large; move, don't copy
		ys[i] = y[perm[i]];
	}
	return std::tuple(std::move(xs), std::move(ys));
}
int main() | |
{ | |
//Load dataset | |
//X.npy is saved in fp32 | |
xt::xarray<float> X = xt::load_npy<float>("X.npy"); | |
//But Y.npy is saved in fp64, needs to convert | |
xt::xarray<float> Y = xt::cast<float>((xt::xarray<float>)xt::load_npy<double>("Y.npy")); | |
assert(X.shape()[0] == Y.shape()[0]); | |
network<sequential> net; | |
net << conv(64, 64, 5, 1, 6) << leaky_relu() // in: 64x64x1, out 6 chanels, kernel size: 5 | |
<< max_pool(60, 60, 6, 2) // in: 60x60x6, 2x2 pooling | |
<< conv(30, 30, 5, 6, 9) << leaky_relu() // in: 30x30x6, out 9 channels, kernel size: 5 | |
<< max_pool(26, 26, 9, 2) // in:26x26x9, 2x2 pooling | |
<< conv(13, 13, 6, 9, 12) << leaky_relu()// in: 13x13x9, out 12 channels, kernel size: 6 | |
<< fc(8*8*12, 10) | |
<< softmax(); | |
auto [x, y] = convertDataset(X, Y); | |
//Setup callbacks so we can see the traning progress | |
size_t numData = x.size(); | |
progress_display disp(numData); | |
timer t; | |
int epoch = 1; | |
auto onEnumerateEpoch = [&]() | |
{ | |
std::cout << std::endl; | |
std::cout << "Epoch " << epoch << "/" << NUM_BATCHES << " finished. " | |
<< t.elapsed() << "s elapsed." << std::endl; | |
++epoch; | |
disp.restart(numData); | |
t.restart(); | |
}; | |
auto onEnumerateMinibatch = [&](){disp += BATCH_SIZE;}; | |
//Train the model | |
adam optimizer; | |
net.train<mse>(optimizer, x, y, BATCH_SIZE, NUM_BATCHES | |
,onEnumerateMinibatch, onEnumerateEpoch); | |
//Test the netowrk | |
net.test(x, y).print_detail(std::cout); | |
net.save("model.net"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment