Skip to content

Instantly share code, notes, and snippets.

@liberize
Created November 18, 2013 12:18
Show Gist options
  • Save liberize/7526893 to your computer and use it in GitHub Desktop.
Save liberize/7526893 to your computer and use it in GitHub Desktop.
OpenCV: Sudoku Number Training
#include <io.h>  // NOTE(review): MSVC-specific; nothing visible in this file uses it

#include <cfloat>
#include <climits>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "opencv2/opencv.hpp"

using namespace std;
using namespace cv;
// Uncomment to make main() train and save the SVM; otherwise it runs prediction.
// #define ON_STUDY

// Side length, in pixels, of the square region read from each sample image.
// NOTE: with `using namespace std;` in effect, C++17's std::size makes an
// unqualified `size` ambiguous, so every use is written as `::size`.
const int size = 16;
// Number of features per sample: one per pixel of the size x size region.
const int featureLen = ::size * ::size;
// Directory holding the sample images, named "<serial>.png".
const string fileDir = "../samples/";
// Text file of whitespace-separated "<serial> <label>" integer pairs.
const string labelFileName = "../label.txt";
// File the trained SVM model is saved to and loaded from.
const string svmDataFile = "../SVM_DATA.xml";
class NumTrainData
{
public:
NumTrainData()
{
memset(data, 0, sizeof(data));
result = -1;
}
public:
float data[size*size];
int result;
};
vector<NumTrainData> buffer;
int ReadTrainData(int maxCount = INT_MAX)
{
// Open image and label file
ifstream lab_ifs(labelFileName);
if(!lab_ifs) return -1;
int sn, lab;
NumTrainData rtd;
int total = 0;
while (lab_ifs >> sn >> lab, !lab_ifs.eof()) {
if(++total > maxCount) break;
cout << total << endl;
if(lab_ifs.fail()) {
cout << "failed to read data from file." << endl;
lab_ifs.close();
return -1;
}
if(lab < 1 || lab > 9) {
cout << "label is invalid. skipped." << endl;
continue;
}
ostringstream oss;
oss << fileDir << sn << ".png";
Mat temp = imread(oss.str(), 0);
if(temp.empty()) {
cout << "failed to read sample #" << sn << ". skipped." << endl;
continue;
}
rtd.result = lab;
for(int i = 0; i < size; i++)
{
for(int j = 0; j < size; j++)
{
rtd.data[i*size+j] = temp.at<uchar>(i, j);
}
}
buffer.push_back(rtd);
}
lab_ifs.close();
return 0;
}
void newSvmStudy(vector<NumTrainData>& trainData)
{
int testCount = trainData.size();
Mat m = Mat::zeros(1, featureLen, CV_32FC1);
Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
Mat res = Mat::zeros(testCount, 1, CV_32SC1);
for (int i= 0; i< testCount; i++)
{
NumTrainData td = trainData.at(i);
memcpy(m.data, td.data, featureLen*sizeof(float));
normalize(m, m);
memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));
res.at<unsigned int>(i, 0) = td.result;
}
/////////////START SVM TRAINNING//////////////////
CvSVM svm = CvSVM();
CvSVMParams param;
CvTermCriteria criteria;
criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
svm.train(data, res, Mat(), Mat(), param);
svm.save(svmDataFile.c_str());
}
int newSvmPredict()
{
ifstream lab_ifs(labelFileName);
if(!lab_ifs) return -1;
int sn, lab;
Mat m = Mat::zeros(1, featureLen, CV_32FC1);
CvSVM svm = CvSVM();
svm.load(svmDataFile.c_str());
while (lab_ifs >> sn >> lab, !lab_ifs.eof()) {
if(lab_ifs.fail()) {
cout << "failed to read data from file." << endl;
lab_ifs.close();
return -1;
}
if(lab < 1 || lab > 9) {
cout << "label is invalid. skipped." << endl;
continue;
}
ostringstream oss;
oss << fileDir << sn << ".png";
Mat temp = imread(oss.str(), 0);
if(temp.empty()) {
cout << "failed to read sample #" << sn << ". skipped." << endl;
continue;
}
for(int i = 0; i < size; i++)
{
for(int j = 0; j < size; j++)
{
m.at<float>(0, i*size+j) = temp.at<uchar>(i, j);
}
}
normalize(m, m);
int ret = int(svm.predict(m));
stringstream ss;
ss << "Number " << lab << ", predict " << ret;
Mat img = Mat::zeros(temp.rows*30, temp.cols*30, CV_8UC3);
resize(temp, img, img.size());
putText(img, ss.str(), Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(0, 0, 255));
imshow("img", img);
waitKey();
}
lab_ifs.close();
return 0;
}
// Entry point. With ON_STUDY defined, reads the labelled samples and trains
// and saves the SVM model; otherwise loads the saved model and evaluates it
// interactively. Exits with 0 on success and 1 on failure.
int main( int argc, char *argv[] )
{
    (void)argc;  // command-line arguments are not used
    (void)argv;
#ifdef ON_STUDY
    if (ReadTrainData() != 0) {
        cout << "failed to read training data." << endl;
        return 1;
    }
    // Nothing usable was loaded; stop with an error rather than training on
    // an empty set.
    if (buffer.empty()) {
        cout << "no valid training samples found." << endl;
        return 1;
    }
    newSvmStudy(buffer);
#else
    if (newSvmPredict() != 0) {
        cout << "prediction failed." << endl;
        return 1;
    }
#endif
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment