Skip to content

Instantly share code, notes, and snippets.

@marty1885
Created March 5, 2018 01:08
Show Gist options
  • Save marty1885/5ef602ae8b02d3494bcf4f9a56f6cd68 to your computer and use it in GitHub Desktop.
Save marty1885/5ef602ae8b02d3494bcf4f9a56f6cd68 to your computer and use it in GitHub Desktop.
#pragma once
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <Athena/Athena.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
Mat loadImage(std::string path)
{
Mat image;
image = imread(path);
image.convertTo(image, CV_32FC3);
return image/255.f;
}
Mat getChannel(const Mat& img, intmax_t id)
{
Mat channel[3];
split(img, channel);
return channel[id];
}
/// Copies a single-channel float matrix into a new tensor of shape {1, 1, rows, cols}.
///
/// @param image  CV_32FC1 matrix. It does not need to be continuous; row-padded
///               matrices such as ROIs are compacted first.
/// @return       Tensor holding the pixels in row-major order.
/// @throws std::invalid_argument if `image` is not CV_32FC1.
inline At::Tensor mat2Tensor(const cv::Mat& image)
{
	// The copy below reinterprets the pixel buffer as plain floats, one per pixel.
	// Any other type/channel count would silently produce garbage.
	if(image.type() != CV_32FC1)
		throw std::invalid_argument("mat2Tensor: expected a single-channel CV_32F matrix");
	const cv::Mat compact = image.isContinuous() ? image : image.clone();
	const float* pixels = compact.ptr<float>();
	std::vector<float> data(pixels, pixels + compact.total());
	return At::Tensor(data, {1, 1, compact.rows, compact.cols});
}
Mat tensor2Mat(const At::Tensor& image)
{
Mat res(image.shape()[2]*image.shape()[1]*image.shape()[0], image.shape()[3], CV_32F);
auto data = image.host();
memcpy(res.data, &data[0], data.size()*sizeof(float));
return res;
}
/// Loads `num` images named `<path>1.bmp` ... `<path><num>.bmp` and packs them into one tensor.
///
/// Each image is loaded with loadImage and reduced to its channel 0. The images are
/// shuffled with a std::minstd_rand0 seeded by `seed`, then packed row-major into a
/// tensor of shape {num, 1, dims[0], dims[1]}.
///
/// The permutation depends only on `seed` and `num` (for a given standard library).
/// Two calls with equal seed and count therefore reorder their files identically.
///
/// @param path  Filename prefix, e.g. a directory ending in '/'.
/// @param dims  Expected image size as {rows, cols}.
/// @param num   Number of images to load.
/// @param seed  Seed for the shuffle engine.
/// @throws std::runtime_error if an image is unreadable or does not have the expected size.
inline At::Tensor loadImages(std::string path, At::Shape dims, intmax_t num, int seed=0x1234abcd)
{
	std::minstd_rand0 eng(seed);
	const intmax_t size = dims.volume();

	std::vector<cv::Mat> images(num);
	for(intmax_t i=0;i<num;i++)
	{
		const std::string file = path+std::to_string(i+1)+".bmp";
		cv::Mat channel = getChannel(loadImage(file), 0);
		// The packing loop reads exactly `size` floats per image. Reject anything else,
		// including the empty Mat produced by an unreadable file, instead of reading
		// out of bounds.
		if(channel.type() != CV_32FC1 || channel.rows != dims[0] || channel.cols != dims[1]
			|| static_cast<intmax_t>(channel.total()) != size)
			throw std::runtime_error("loadImages: '" + file + "' is not a "
				+ std::to_string(dims[0]) + "x" + std::to_string(dims[1]) + " image");
		images[i] = channel.isContinuous() ? channel : channel.clone();
	}
	std::shuffle(images.begin(), images.end(), eng);

	std::vector<float> rawData(num*size);
	for(intmax_t i=0;i<num;i++)
	{
		const float* pixels = images[i].ptr<float>();
		std::copy(pixels, pixels+size, rawData.begin()+i*size);
	}
	return At::Tensor(rawData, {num, 1, dims[0], dims[1]});
}
decltype(auto) loadDataset()
{
const int NUM_IMAGES = 1400;
std::cout << "Loading input data" << std::endl;
At::Tensor input = loadImages("../dataset/input/", {33,33}, NUM_IMAGES);
std::cout << "Loading output data" << std::endl;
At::Tensor output = loadImages("../dataset/label/", {21,21}, NUM_IMAGES);
return std::make_tuple(input, output);
}
/// Builds and compiles the three-layer SRCNN model on At::Tensor::defaultBackend().
///
/// The layers are Conv2D(1, 64, 9x9) + ReLU, Conv2D(64, 32, 1x1) + ReLU, and
/// Conv2D(32, 1, 5x5). The arguments are presumably (in channels, out channels, kernel).
/// If the convolutions are unpadded, a 33x33 input shrinks to 21x21 (33 - 8 - 0 - 4),
/// which matches the label size loaded by loadDataset.
inline At::SequentialNetwork SRCNN()
{
	At::SequentialNetwork net(At::Tensor::defaultBackend());
	net << At::Conv2DLayer(1, 64, {9,9}) << At::ReluLayer()
		<< At::Conv2DLayer(64, 32, {1,1}) << At::ReluLayer()
		<< At::Conv2DLayer(32, 1, {5,5});
	net.compile();
	// Return the local by name. The compiler can elide the copy (NRVO) and otherwise
	// moves implicitly; wrapping it in std::move() would disable NRVO.
	return net;
}
#include <Athena/Athena.hpp>
#include <Athena/XtensorBackend.hpp>
#include <Athena/NNPACKBackend.hpp>
#include <Athena/Utils/Archive.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
using namespace cv;
#include <tuple>
#include <vector>
#include <chrono>
#include "srcnn.hpp"
using namespace std::chrono;
// Learning rate assigned to the AdaGrad optimizer's alpha_ in main().
constexpr float LR_BASE = 6.5e-5f;
int main()
{
At::XtensorBackend backend;
At::NNPackBackend nnpBackend;
backend.useAlgorithm<At::FCForwardFunction>("fullyconnectedForward", nnpBackend);
backend.useAlgorithm<At::FCBackwardFunction>("fullyconnectedBackward", nnpBackend);
backend.useAlgorithm<At::Conv2DBackward>("conv2DBackward",nnpBackend);
backend.useAlgorithm<At::Conv2DForward>("conv2DForward",nnpBackend);
At::Tensor::setDefaultBackend(&backend);
At::SequentialNetwork net = SRCNN();
net.summary({At::Shape::None, 1, 33, 33});
At::MSELoss loss;
At::AdaGradOptimizer oprimizer;
oprimizer.alpha_ = LR_BASE;
auto [input, output] = loadDataset();
size_t epoch = 20;
size_t batchSize = 16;
int count = 0;
int currentEpoch = 0;
auto onBatch = [&](float l)
{
int sampleNum = input.shape()[0];
count += batchSize;
fputs("\033[2K\r",stdout);
std::cout << l << ", progress = " << count << "/" << sampleNum << std::flush;
};
auto onEpoch = [&](float l)
{
fputs("\033[2K\r",stdout);
currentEpoch++;
std::cout << "Epoch " << currentEpoch << ", Epoch Loss: " << l << std::endl;
count = 0;
};
high_resolution_clock::time_point t1 = high_resolution_clock::now();
net.fit(oprimizer,loss,input,output,batchSize,epoch,onBatch,onEpoch);
high_resolution_clock::time_point t2 = high_resolution_clock::now();
duration<double> time_span = duration_cast<duration<double>>(t2 - t1);
std::cout << "It took me " << time_span.count() << " seconds." << std::endl;
At::save(net.states(), "net.json");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment