Created
March 5, 2018 01:08
-
-
Save marty1885/5ef602ae8b02d3494bcf4f9a56f6cd68 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <Athena/Athena.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
Mat loadImage(std::string path) | |
{ | |
Mat image; | |
image = imread(path); | |
image.convertTo(image, CV_32FC3); | |
return image/255.f; | |
} | |
Mat getChannel(const Mat& img, intmax_t id) | |
{ | |
Mat channel[3]; | |
split(img, channel); | |
return channel[id]; | |
} | |
/// Copy a single-channel float image into a tensor of shape {1, 1, rows, cols}.
///
/// @param image  CV_32FC1 matrix. A non-continuous matrix (e.g. an ROI view
///               into a larger image) is cloned first so its rows are packed.
/// @return       Tensor of shape {1, 1, image.rows, image.cols} holding a copy
///               of the pixel values in row-major order.
/// @throws std::invalid_argument if @p image is not CV_32FC1.
inline At::Tensor mat2Tensor(const cv::Mat& image)
{
    // The copy below reinterprets the buffer as rows*cols floats; any other
    // element type or channel count would read the wrong bytes or run past it.
    if(image.type() != CV_32FC1)
        throw std::invalid_argument("mat2Tensor: expected a CV_32FC1 image");
    // An ROI has padding between rows, so a single memcpy would pick up the
    // wrong pixels; pack it into contiguous storage first.
    const cv::Mat packed = image.isContinuous() ? image : image.clone();
    const int rows = packed.rows;
    const int cols = packed.cols;
    std::vector<float> data(static_cast<std::size_t>(rows)*static_cast<std::size_t>(cols));
    if(!data.empty())
        std::memcpy(data.data(), packed.ptr<float>(), data.size()*sizeof(float));
    return At::Tensor(data, {1, 1, rows, cols});
}
Mat tensor2Mat(const At::Tensor& image) | |
{ | |
Mat res(image.shape()[2]*image.shape()[1]*image.shape()[0], image.shape()[3], CV_32F); | |
auto data = image.host(); | |
memcpy(res.data, &data[0], data.size()*sizeof(float)); | |
return res; | |
} | |
/// Load `num` images named "<path>1.bmp" .. "<path><num>.bmp", shuffle them
/// and pack them into a single tensor.
///
/// Only channel 0 of each decoded image is kept (NOTE(review): presumably the
/// BMPs are grayscale so every channel is identical; confirm). The shuffle uses
/// a std::minstd_rand0 seeded with @p seed. Two calls with the same @p num and
/// @p seed therefore apply the same permutation, which is what keeps the input
/// and label tensors of loadDataset aligned.
///
/// @param path  Prefix prepended to each file name (include the trailing '/').
/// @param dims  Expected image size {rows, cols}; every image must contain
///              exactly dims.volume() pixels.
/// @param num   Number of images to load; must be non-negative.
/// @param seed  Seed for the shuffle.
/// @return      Tensor of shape {num, 1, dims[0], dims[1]}.
/// @throws std::invalid_argument if @p num is negative.
/// @throws std::runtime_error    if an image does not have the expected size.
inline At::Tensor loadImages(const std::string& path, At::Shape dims, intmax_t num, int seed=0x1234abcd)
{
    if(num < 0)
        throw std::invalid_argument("loadImages: num must be non-negative");

    std::minstd_rand0 eng(seed);
    const intmax_t size = dims.volume();

    std::vector<cv::Mat> images(static_cast<std::size_t>(num));
    for(intmax_t i=0;i<num;i++)
    {
        const std::string file = path+std::to_string(i+1)+".bmp";
        cv::Mat channel = getChannel(loadImage(file), 0);
        // Exactly `size` floats are read from each image below; a missing file
        // (empty Mat) or a mis-sized image would make that read go out of bounds.
        if(channel.type() != CV_32FC1 || static_cast<intmax_t>(channel.total()) != size)
            throw std::runtime_error("loadImages: '" + file + "' does not match the expected dimensions");
        if(!channel.isContinuous())
            channel = channel.clone();
        images[static_cast<std::size_t>(i)] = channel;
    }

    std::shuffle(images.begin(), images.end(), eng);

    std::vector<float> rawData(static_cast<std::size_t>(num*size));
    for(intmax_t i=0;i<num;i++)
    {
        const float* src = images[static_cast<std::size_t>(i)].ptr<float>();
        std::copy(src, src+size, rawData.begin()+i*size);
    }
    return At::Tensor(rawData, {num, 1, dims[0], dims[1]});
}
decltype(auto) loadDataset() | |
{ | |
const int NUM_IMAGES = 1400; | |
std::cout << "Loading input data" << std::endl; | |
At::Tensor input = loadImages("../dataset/input/", {33,33}, NUM_IMAGES); | |
std::cout << "Loading output data" << std::endl; | |
At::Tensor output = loadImages("../dataset/label/", {21,21}, NUM_IMAGES); | |
return std::make_tuple(input, output); | |
} | |
/// Build and compile the three-layer SRCNN model on the current default backend:
/// 9x9 conv (1->64 channels) + ReLU, 1x1 conv (64->32) + ReLU, 5x5 conv (32->1).
///
/// @return The compiled network, ready for training.
inline At::SequentialNetwork SRCNN()
{
    At::SequentialNetwork net(At::Tensor::defaultBackend());
    net << At::Conv2DLayer(1, 64, {9,9}) << At::ReluLayer()
        << At::Conv2DLayer(64, 32, {1,1}) << At::ReluLayer()
        << At::Conv2DLayer(32, 1, {5, 5});
    net.compile();
    // Return by name: this permits NRVO and otherwise falls back to an implicit
    // move, whereas `return std::move(net);` disables copy elision.
    return net;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <Athena/Athena.hpp> | |
#include <Athena/XtensorBackend.hpp> | |
#include <Athena/NNPACKBackend.hpp> | |
#include <Athena/Utils/Archive.hpp> | |
#include <opencv2/core/core.hpp> | |
#include <opencv2/highgui/highgui.hpp> | |
#include <opencv2/imgproc/imgproc.hpp> | |
using namespace cv; | |
#include <tuple> | |
#include <vector> | |
#include <chrono> | |
#include "srcnn.hpp" | |
using namespace std::chrono; | |
const float LR_BASE = 0.000065f; | |
int main() | |
{ | |
At::XtensorBackend backend; | |
At::NNPackBackend nnpBackend; | |
backend.useAlgorithm<At::FCForwardFunction>("fullyconnectedForward", nnpBackend); | |
backend.useAlgorithm<At::FCBackwardFunction>("fullyconnectedBackward", nnpBackend); | |
backend.useAlgorithm<At::Conv2DBackward>("conv2DBackward",nnpBackend); | |
backend.useAlgorithm<At::Conv2DForward>("conv2DForward",nnpBackend); | |
At::Tensor::setDefaultBackend(&backend); | |
At::SequentialNetwork net = SRCNN(); | |
net.summary({At::Shape::None, 1, 33, 33}); | |
At::MSELoss loss; | |
At::AdaGradOptimizer oprimizer; | |
oprimizer.alpha_ = LR_BASE; | |
auto [input, output] = loadDataset(); | |
size_t epoch = 20; | |
size_t batchSize = 16; | |
int count = 0; | |
int currentEpoch = 0; | |
auto onBatch = [&](float l) | |
{ | |
int sampleNum = input.shape()[0]; | |
count += batchSize; | |
fputs("\033[2K\r",stdout); | |
std::cout << l << ", progress = " << count << "/" << sampleNum << std::flush; | |
}; | |
auto onEpoch = [&](float l) | |
{ | |
fputs("\033[2K\r",stdout); | |
currentEpoch++; | |
std::cout << "Epoch " << currentEpoch << ", Epoch Loss: " << l << std::endl; | |
count = 0; | |
}; | |
high_resolution_clock::time_point t1 = high_resolution_clock::now(); | |
net.fit(oprimizer,loss,input,output,batchSize,epoch,onBatch,onEpoch); | |
high_resolution_clock::time_point t2 = high_resolution_clock::now(); | |
duration<double> time_span = duration_cast<duration<double>>(t2 - t1); | |
std::cout << "It took me " << time_span.count() << " seconds." << std::endl; | |
At::save(net.states(), "net.json"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment