Skip to content

Instantly share code, notes, and snippets.

@marty1885
Last active July 6, 2017 16:37
Show Gist options
  • Save marty1885/15cf347d007b6fa60b91b06a9cc3ffeb to your computer and use it in GitHub Desktop.
Save marty1885/15cf347d007b6fa60b91b06a9cc3ffeb to your computer and use it in GitHub Desktop.
#include <iostream>
#include <vector>
#include <string>
#include <random>
#include <assert.h>
#define STB_IMAGE_IMPLEMENTATION
#include "3rdparty/stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "3rdparty/stb_image_write.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#include "3rdparty/stb_image_resize.h"
#define CNN_USE_AVX
#define CNN_USE_TBB
#include "tiny_dnn/tiny_dnn.h"
using namespace tiny_dnn;
using namespace tiny_dnn::activation;
using namespace tiny_dnn::layers;
using namespace std;
/// Three-component float vector, used throughout this file as an RGB pixel.
/// The components can be read by name (x, y, z) or by index (operator[] / data);
/// both names alias the same storage through the anonymous union.
struct float3
{
	float3() : float3(0,0,0)
	{
	}
	float3(const float3& other) : float3(other.x, other.y, other.z)
	{
	}
	float3(float _x,float _y,float _z)
	{
		x = _x;
		y = _y;
		z = _z;
	}
	// Declared explicitly: with a user-declared copy constructor the implicitly
	// generated copy assignment is deprecated.
	float3& operator=(const float3& other) = default;
	/// Mutable access to component `index` (0 = x, 1 = y, 2 = z). Not bounds-checked.
	inline float& operator[] (int index)
	{
		return data[index];
	}
	/// Read-only access to component `index` (0 = x, 1 = y, 2 = z), usable on const objects.
	inline const float& operator[] (int index) const
	{
		return data[index];
	}
	union
	{
		float data[3];
		struct {float x,y,z;};
	};
};
/// Restrict val to the closed range [minVal, maxVal].
/// The upper bound is tested first. A value that compares false against both
/// bounds (e.g. NaN) is returned unchanged.
inline float clamp(float val, float minVal, float maxVal)
{
	const bool aboveRange = val > maxVal;
	const bool belowRange = val < minVal;
	return aboveRange ? maxVal : (belowRange ? minVal : val);
}
class Image
{
public:
bool load(const std::string& path, float gamma = 2.2)
{
//Clear any previous data. Although in any good programming. This shouldn't happen
int comp = 0;
float* imageData = stbi_loadf(path.c_str(), &width, &height, &comp, 4);
if(imageData == nullptr)
{
std::cout << "Filed to load image " << path << std::endl;
return false;
}
createEmpty(width, height);
for(int i=0;i<height;i++)
{
for(int j=0;j<width;j++)
{
float3 pixVal = float3(imageData[4*(i*width+j)+0],imageData[4*(i*width+j)+1]
,imageData[4*(i*width+j)+2]);
//Flip the image verticaly. Acculding to the OpenGL convention, the origin is at bottom-left.
//But STBI assumes the origin is at the top-left.
// /2.2 because stbi_loadf by default turns everything into linear rgb.
pixVal.x = powf(pixVal.x, gamma/2.2);
pixVal.y = powf(pixVal.y, gamma/2.2);
pixVal.z = powf(pixVal.z, gamma/2.2);
buffer[(height-i-1)*width+j] = pixVal;
}
}
stbi_image_free(imageData);
return true;
}
bool good() const
{
return buffer.size() != 0;
}
float3 getPixel(int x, int y) const
{
return buffer[y*width+x];
}
void setPixel(int x, int y, float3 val)
{
buffer[y*width+x] = val;
}
void createEmpty(int w, int h)
{
width = w;
height = h;
buffer.resize(width*height);
}
Image subImage(int x, int y, int w, int h)
{
Image img;
img.createEmpty(w,h);
for(int i=0;i<w;i++)
{
for(int j=0;j<h;j++)
{
float3 pixVal = getPixel(j+x,i+y);
img.setPixel(j,i, pixVal);
}
}
return img;
}
Image flipImage(bool flipX, bool flipY)
{
Image img;
img.createEmpty(width, height);
for(int i=0;i<width;i++)
{
int yCoord = flipY ? (height-1-i) : i;
for(int j=0;j<height;j++)
{
float3 pixVal = getPixel(j,i);
int xCoord = flipX ? (width-1-j) : j;
img.setPixel(xCoord,yCoord, pixVal);
}
}
return img;
}
Image resize(int w, int h)
{
Image img;
img.createEmpty(w,h);
stbir_resize_float((float*)&buffer[0], width, height,0,
(float*)&img.buffer[0],w,h,0
,3);
return img;
}
bool save(std::string path, float gamma=2.2)
{
if(!good())
{
std::cout << "No image is in the buffer! can't save image." << std::endl;
return false;
}
std::string fileExt = path.substr(path.find_last_of(".") + 1);
if(fileExt != "png" && fileExt != "bmp" && fileExt != "tga")
{
std::cout << "Error : Image format *" << fileExt << "* not supported." << std::endl << std::endl;
return false;
}
const float inversedGamma = 1.f/gamma;
unsigned char* imageData = new unsigned char[width*height*3];
for(int i=0;i<height;i++)
{
for(int j=0;j<width;j++)
{
int index = i*width+j;
float3 pixVal = getPixel(j,i);
pixVal.x = powf(pixVal.x, inversedGamma)*255.0f;
pixVal.y = powf(pixVal.y, inversedGamma)*255.0f;
pixVal.z = powf(pixVal.z, inversedGamma)*255.0f;
pixVal.x = clamp(pixVal.x,0,255);
pixVal.y = clamp(pixVal.y,0,255);
pixVal.z = clamp(pixVal.z,0,255);
for(int k=0;k<3;k++)
imageData[((height-i-1)*width+j)*3+k] = pixVal[k];
}
}
int success = 0;
if(fileExt == "png")
success = stbi_write_png(path.c_str(), width, height, 3, imageData, width*3);
else if(fileExt == "bmp")
success = stbi_write_bmp(path.c_str(), width, height, 3, imageData);
else if(fileExt == "tga")
success = stbi_write_tga(path.c_str(), width, height, 3, imageData);
delete [] imageData;
if(success == 0)
{
std::cout << "Error : failed to save image as *" << path << "*." << std::endl << std::endl;
return false;
}
return true;
}
std::vector<float3> buffer;
int width;
int height;
};
/// Left-pad str with '0' characters until it is at least `length` characters long,
/// e.g. leftpad("7", 3) == "007".
/// Strings that are already long enough, and non-positive lengths, return str unchanged.
inline std::string leftpad(const std::string& str, int length)
{
	// Compare in signed arithmetic so a negative length never wraps to a huge size_t.
	const long long missing = static_cast<long long>(length) - static_cast<long long>(str.size());
	if(missing <= 0)
		return str;
	// Build all the padding at once instead of prepending one character per
	// iteration (quadratic).
	return std::string(static_cast<std::size_t>(missing), '0') + str;
}
/// Uniformly distributed integer in the closed range [minVal, maxVal].
/// Every call draws from a single std::mt19937, seeded once from time(0) on first use.
/// Precondition (from std::uniform_int_distribution): minVal <= maxVal.
inline int randInt(int minVal, int maxVal)
{
	static std::mt19937 sharedEngine(time(0));
	return std::uniform_int_distribution<int>(minVal, maxVal)(sharedEngine);
}
/// Flatten an Image into a planar vector for the network: all x (red) values,
/// then all y (green) values, then all z (blue) values, each plane in buffer order.
/// The result has 3 * image.buffer.size() elements. The image is only read.
inline vec_t fromImage(const Image& image)
{
	const size_t pixelCount = image.buffer.size();
	vec_t vec(pixelCount*3);
	for(size_t i=0;i<pixelCount;i++)
	{
		const float3& px = image.buffer[i];
		vec[i] = px.x;
		vec[pixelCount+i] = px.y;
		vec[2*pixelCount+i] = px.z;
	}
	return vec;
}
/// Rebuild a w x h Image from a planar vector laid out as fromImage produces it:
/// w*h x (red) values, then w*h y values, then w*h z values.
/// Returns an empty Image (good() == false) and prints a message if w or h is
/// negative or vec holds fewer than 3*w*h values. This avoids reading past the
/// end of vec.
inline Image fromVector(const vec_t& vec, int w, int h)
{
	Image img;
	if(w < 0 || h < 0)
	{
		std::cout << "fromVector: invalid size " << w << "x" << h << std::endl;
		return img;
	}
	const size_t size = static_cast<size_t>(w)*static_cast<size_t>(h);
	if(vec.size() < size*3)
	{
		std::cout << "fromVector: got " << vec.size() << " values, need " << size*3 << std::endl;
		return img;
	}
	img.createEmpty(w,h);
	for(size_t i=0;i<size;i++)
	{
		float3& px = img.buffer[i];
		px.x = vec[i];
		px.y = vec[size+i];
		px.z = vec[size*2+i];
	}
	return img;
}
int main()
{
const constexpr int epochSampleSize = 8;
const constexpr int batchSize = 8;
const constexpr int imageReuseMul = 1;
const constexpr int epochNum = 100;
const constexpr int trainNum = 100;
std::vector<Image> input(epochSampleSize);
std::vector<Image> output(epochSampleSize);
int num = 0;
const constexpr int outputSize = 398;
network<sequential> net;
net << conv(142,142,3,3,32) << leaky_relu_layer()
<< conv(140,140,3,32,64) << leaky_relu_layer()
<< conv(138,138,3,64,64) << leaky_relu_layer()
<< conv(136,136,3,64,128) << leaky_relu_layer()
<< conv(134,134,3,128,128) << leaky_relu_layer();
int windowSize = 3;
int stride = 3;
net << deconvolutional_layer(132,132,windowSize,128,3
,padding::valid, true, stride, stride);
cout << net[10]->out_data_size() << endl;
adam optimizer;
static_assert(epochSampleSize % imageReuseMul == 0);
std::mt19937 engine(time(0));
timer t;
progress_display disp(epochSampleSize*epochNum);
for(int currentTrain = 0; currentTrain < trainNum;currentTrain++)
{
//Load data
for(int i=0;i<epochSampleSize;i+=imageReuseMul)
{
std::string name = "dataset/" + std::to_string(randInt(0,79)) + ".jpg";
Image img;
img.load(name);
for(int j=0;j<imageReuseMul;j++)
{
output[i+j] = img.subImage(randInt(0, img.width-1-outputSize),
randInt(0, img.height-1-outputSize),outputSize,outputSize);
output[i+j] = output[i+j].flipImage(randInt(0,1),randInt(0,1));
input[i+j] = output[i+j].resize(142,142);
}
}
std::vector<int> indices(epochSampleSize);
for(int i=0;i<epochSampleSize;i++)
indices[i] = i;
std::shuffle(indices.begin(), indices.end(), engine);
std::vector<vec_t> inputVectors(epochSampleSize);
std::vector<vec_t> outputVectos(epochSampleSize);
for(int j=0;j<epochSampleSize;j++)
inputVectors[j] = fromImage(input[indices[j]]);
for(int j=0;j<epochSampleSize;j++)
outputVectos[j] = fromImage(output[indices[j]]);
auto onEpoch = [&]()
{
};
auto onMinibatch = [&]()
{
disp += batchSize;
};
net.fit<mse>(optimizer, inputVectors, outputVectos, batchSize, epochNum
, onMinibatch, onEpoch);
std::cout << "Epoch " << currentTrain << "/" << trainNum << " finished. "
<< t.elapsed() << "s elapsed." << std::endl;
disp.restart(epochSampleSize*epochNum);
t.restart();
net.save("save/mikunet" + std::to_string(currentTrain+1));
for(int i=0;i<8;i++)
{
//test the network
Image img2 = fromVector(outputVectos[i],outputSize,outputSize);
img2.save("testresult/" + std::to_string(currentTrain) + "_" + std::to_string(i) + "_origin.png");
auto res = net.predict(inputVectors[i]);
Image img = fromVector(res,outputSize,outputSize);
img.save("testresult/" + std::to_string(currentTrain) + "_" + std::to_string(i) + "_large.png");
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment