Last active
July 6, 2017 16:37
-
-
Save marty1885/15cf347d007b6fa60b91b06a9cc3ffeb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <vector> | |
#include <string> | |
#include <random> | |
#include <assert.h> | |
#define STB_IMAGE_IMPLEMENTATION | |
#include "3rdparty/stb_image.h" | |
#define STB_IMAGE_WRITE_IMPLEMENTATION | |
#include "3rdparty/stb_image_write.h" | |
#define STB_IMAGE_RESIZE_IMPLEMENTATION | |
#include "3rdparty/stb_image_resize.h" | |
#define CNN_USE_AVX | |
#define CNN_USE_TBB | |
#include "tiny_dnn/tiny_dnn.h" | |
using namespace tiny_dnn; | |
using namespace tiny_dnn::activation; | |
using namespace tiny_dnn::layers; | |
using namespace std; | |
/// Three-component float vector, used throughout this file as an RGB pixel.
/// Components are reachable by name (x, y, z) or by index (operator[] / data);
/// both views share the same storage through an anonymous union.
/// NOTE: an anonymous struct inside a union is a widely supported compiler
/// extension (GCC/Clang/MSVC), not standard C++.
struct float3
{
	/// Zero-initialized vector.
	float3() : float3(0, 0, 0)
	{
	}

	/// Vector with the given components.
	float3(float _x, float _y, float _z)
	{
		x = _x;
		y = _y;
		z = _z;
	}

	// Copying is left implicit (Rule of Zero): the members are plain floats, so the
	// compiler-generated copy is trivial, which keeps float3 trivially copyable.

	/// Mutable access to component `index` (0 = x, 1 = y, 2 = z). Not bounds-checked.
	inline float& operator[] (int index)
	{
		return data[index];
	}

	/// Read-only access, so const pixels can be indexed as well. Not bounds-checked.
	inline const float& operator[] (int index) const
	{
		return data[index];
	}

	union
	{
		float data[3];
		struct {float x, y, z;};
	};
};
/// Restricts `val` to [minVal, maxVal].
/// The upper bound is checked first; a NaN input fails both comparisons and is
/// returned as-is.
inline float clamp(float val, float minVal, float maxVal)
{
	return (val > maxVal) ? maxVal
	     : (val < minVal) ? minVal
	     : val;
}
class Image | |
{ | |
public: | |
bool load(const std::string& path, float gamma = 2.2) | |
{ | |
//Clear any previous data. Although in any good programming. This shouldn't happen | |
int comp = 0; | |
float* imageData = stbi_loadf(path.c_str(), &width, &height, &comp, 4); | |
if(imageData == nullptr) | |
{ | |
std::cout << "Filed to load image " << path << std::endl; | |
return false; | |
} | |
createEmpty(width, height); | |
for(int i=0;i<height;i++) | |
{ | |
for(int j=0;j<width;j++) | |
{ | |
float3 pixVal = float3(imageData[4*(i*width+j)+0],imageData[4*(i*width+j)+1] | |
,imageData[4*(i*width+j)+2]); | |
//Flip the image verticaly. Acculding to the OpenGL convention, the origin is at bottom-left. | |
//But STBI assumes the origin is at the top-left. | |
// /2.2 because stbi_loadf by default turns everything into linear rgb. | |
pixVal.x = powf(pixVal.x, gamma/2.2); | |
pixVal.y = powf(pixVal.y, gamma/2.2); | |
pixVal.z = powf(pixVal.z, gamma/2.2); | |
buffer[(height-i-1)*width+j] = pixVal; | |
} | |
} | |
stbi_image_free(imageData); | |
return true; | |
} | |
bool good() const | |
{ | |
return buffer.size() != 0; | |
} | |
float3 getPixel(int x, int y) const | |
{ | |
return buffer[y*width+x]; | |
} | |
void setPixel(int x, int y, float3 val) | |
{ | |
buffer[y*width+x] = val; | |
} | |
void createEmpty(int w, int h) | |
{ | |
width = w; | |
height = h; | |
buffer.resize(width*height); | |
} | |
Image subImage(int x, int y, int w, int h) | |
{ | |
Image img; | |
img.createEmpty(w,h); | |
for(int i=0;i<w;i++) | |
{ | |
for(int j=0;j<h;j++) | |
{ | |
float3 pixVal = getPixel(j+x,i+y); | |
img.setPixel(j,i, pixVal); | |
} | |
} | |
return img; | |
} | |
Image flipImage(bool flipX, bool flipY) | |
{ | |
Image img; | |
img.createEmpty(width, height); | |
for(int i=0;i<width;i++) | |
{ | |
int yCoord = flipY ? (height-1-i) : i; | |
for(int j=0;j<height;j++) | |
{ | |
float3 pixVal = getPixel(j,i); | |
int xCoord = flipX ? (width-1-j) : j; | |
img.setPixel(xCoord,yCoord, pixVal); | |
} | |
} | |
return img; | |
} | |
Image resize(int w, int h) | |
{ | |
Image img; | |
img.createEmpty(w,h); | |
stbir_resize_float((float*)&buffer[0], width, height,0, | |
(float*)&img.buffer[0],w,h,0 | |
,3); | |
return img; | |
} | |
bool save(std::string path, float gamma=2.2) | |
{ | |
if(!good()) | |
{ | |
std::cout << "No image is in the buffer! can't save image." << std::endl; | |
return false; | |
} | |
std::string fileExt = path.substr(path.find_last_of(".") + 1); | |
if(fileExt != "png" && fileExt != "bmp" && fileExt != "tga") | |
{ | |
std::cout << "Error : Image format *" << fileExt << "* not supported." << std::endl << std::endl; | |
return false; | |
} | |
const float inversedGamma = 1.f/gamma; | |
unsigned char* imageData = new unsigned char[width*height*3]; | |
for(int i=0;i<height;i++) | |
{ | |
for(int j=0;j<width;j++) | |
{ | |
int index = i*width+j; | |
float3 pixVal = getPixel(j,i); | |
pixVal.x = powf(pixVal.x, inversedGamma)*255.0f; | |
pixVal.y = powf(pixVal.y, inversedGamma)*255.0f; | |
pixVal.z = powf(pixVal.z, inversedGamma)*255.0f; | |
pixVal.x = clamp(pixVal.x,0,255); | |
pixVal.y = clamp(pixVal.y,0,255); | |
pixVal.z = clamp(pixVal.z,0,255); | |
for(int k=0;k<3;k++) | |
imageData[((height-i-1)*width+j)*3+k] = pixVal[k]; | |
} | |
} | |
int success = 0; | |
if(fileExt == "png") | |
success = stbi_write_png(path.c_str(), width, height, 3, imageData, width*3); | |
else if(fileExt == "bmp") | |
success = stbi_write_bmp(path.c_str(), width, height, 3, imageData); | |
else if(fileExt == "tga") | |
success = stbi_write_tga(path.c_str(), width, height, 3, imageData); | |
delete [] imageData; | |
if(success == 0) | |
{ | |
std::cout << "Error : failed to save image as *" << path << "*." << std::endl << std::endl; | |
return false; | |
} | |
return true; | |
} | |
std::vector<float3> buffer; | |
int width; | |
int height; | |
}; | |
/// Left-pads `str` with `fill` (default '0') until it is `length` characters long.
/// Strings already at least that long, or a non-positive `length`, are returned
/// unchanged. Example: leftpad("7", 3) == "007".
inline std::string leftpad(const std::string& str, int length, char fill = '0')
{
	// Compare in size_t only after ruling out negative lengths (avoids a
	// signed/unsigned mismatch), and build the padding in one allocation
	// instead of prepending one character at a time.
	if(length <= 0 || str.size() >= static_cast<std::size_t>(length))
		return str;
	return std::string(static_cast<std::size_t>(length) - str.size(), fill) + str;
}
/// Uniformly distributed integer in the closed range [minVal, maxVal].
/// Requires minVal <= maxVal. All calls share one Mersenne Twister, seeded from
/// time(0) on the first call.
inline int randInt(int minVal, int maxVal)
{
	static std::mt19937 generator(time(0));
	return std::uniform_int_distribution<int>{minVal, maxVal}(generator);
}
/// Flattens an image into a planar (channel-major) vector: every pixel's x (R)
/// first, then every y (G), then every z (B), each in buffer order.
/// Inverse of fromVector(). The image is only read, so it is taken by const
/// reference (callers passing non-const images still work).
inline vec_t fromImage(const Image& image)
{
	const std::size_t size = image.buffer.size();
	vec_t vec(size*3);
	for(std::size_t i = 0; i < size; i++)
	{
		const float3& px = image.buffer[i];
		vec[i] = px.x;
		vec[size + i] = px.y;
		vec[2*size + i] = px.z;
	}
	return vec;
}
/// Rebuilds a w x h image from a planar vector as produced by fromImage():
/// vec[0, w*h) holds x (R), vec[w*h, 2*w*h) holds y (G), and
/// vec[2*w*h, 3*w*h) holds z (B).
/// @pre vec.size() >= 3*w*h (checked with assert in debug builds).
inline Image fromVector(const vec_t& vec, int w, int h)
{
	const std::size_t size = static_cast<std::size_t>(w)*h;
	// A mismatch between the requested dimensions and the vector (e.g. a network
	// output of a different size) would otherwise read past the end of vec.
	assert(vec.size() >= size*3);
	Image img;
	img.createEmpty(w, h);
	for(std::size_t i = 0; i < size; i++)
		img.buffer[i] = float3(vec[i], vec[size + i], vec[2*size + i]);
	return img;
}
int main() | |
{ | |
const constexpr int epochSampleSize = 8; | |
const constexpr int batchSize = 8; | |
const constexpr int imageReuseMul = 1; | |
const constexpr int epochNum = 100; | |
const constexpr int trainNum = 100; | |
std::vector<Image> input(epochSampleSize); | |
std::vector<Image> output(epochSampleSize); | |
int num = 0; | |
const constexpr int outputSize = 398; | |
network<sequential> net; | |
net << conv(142,142,3,3,32) << leaky_relu_layer() | |
<< conv(140,140,3,32,64) << leaky_relu_layer() | |
<< conv(138,138,3,64,64) << leaky_relu_layer() | |
<< conv(136,136,3,64,128) << leaky_relu_layer() | |
<< conv(134,134,3,128,128) << leaky_relu_layer(); | |
int windowSize = 3; | |
int stride = 3; | |
net << deconvolutional_layer(132,132,windowSize,128,3 | |
,padding::valid, true, stride, stride); | |
cout << net[10]->out_data_size() << endl; | |
adam optimizer; | |
static_assert(epochSampleSize % imageReuseMul == 0); | |
std::mt19937 engine(time(0)); | |
timer t; | |
progress_display disp(epochSampleSize*epochNum); | |
for(int currentTrain = 0; currentTrain < trainNum;currentTrain++) | |
{ | |
//Load data | |
for(int i=0;i<epochSampleSize;i+=imageReuseMul) | |
{ | |
std::string name = "dataset/" + std::to_string(randInt(0,79)) + ".jpg"; | |
Image img; | |
img.load(name); | |
for(int j=0;j<imageReuseMul;j++) | |
{ | |
output[i+j] = img.subImage(randInt(0, img.width-1-outputSize), | |
randInt(0, img.height-1-outputSize),outputSize,outputSize); | |
output[i+j] = output[i+j].flipImage(randInt(0,1),randInt(0,1)); | |
input[i+j] = output[i+j].resize(142,142); | |
} | |
} | |
std::vector<int> indices(epochSampleSize); | |
for(int i=0;i<epochSampleSize;i++) | |
indices[i] = i; | |
std::shuffle(indices.begin(), indices.end(), engine); | |
std::vector<vec_t> inputVectors(epochSampleSize); | |
std::vector<vec_t> outputVectos(epochSampleSize); | |
for(int j=0;j<epochSampleSize;j++) | |
inputVectors[j] = fromImage(input[indices[j]]); | |
for(int j=0;j<epochSampleSize;j++) | |
outputVectos[j] = fromImage(output[indices[j]]); | |
auto onEpoch = [&]() | |
{ | |
}; | |
auto onMinibatch = [&]() | |
{ | |
disp += batchSize; | |
}; | |
net.fit<mse>(optimizer, inputVectors, outputVectos, batchSize, epochNum | |
, onMinibatch, onEpoch); | |
std::cout << "Epoch " << currentTrain << "/" << trainNum << " finished. " | |
<< t.elapsed() << "s elapsed." << std::endl; | |
disp.restart(epochSampleSize*epochNum); | |
t.restart(); | |
net.save("save/mikunet" + std::to_string(currentTrain+1)); | |
for(int i=0;i<8;i++) | |
{ | |
//test the network | |
Image img2 = fromVector(outputVectos[i],outputSize,outputSize); | |
img2.save("testresult/" + std::to_string(currentTrain) + "_" + std::to_string(i) + "_origin.png"); | |
auto res = net.predict(inputVectors[i]); | |
Image img = fromVector(res,outputSize,outputSize); | |
img.save("testresult/" + std::to_string(currentTrain) + "_" + std::to_string(i) + "_large.png"); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment