Created
November 18, 2013 12:18
-
-
Save liberize/7526893 to your computer and use it in GitHub Desktop.
OpenCV: Sudoku Number Training
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <fstream> | |
#include "opencv2/opencv.hpp" | |
#include <vector> | |
#include <io.h> | |
using namespace std; | |
using namespace cv; | |
// Build switch: when ON_STUDY is defined, main() reads the samples and trains
// the SVM; when it is left undefined, main() runs prediction instead.
// #define ON_STUDY

// Side length, in pixels, of the square region read from every sample image.
// NOTE(review): `using namespace std;` is active in this file, so this name
// may clash with std::size when built as C++17 or later - confirm the target
// language standard.
const int size = 16;
// Number of values in one flattened sample (size rows of size pixels).
const int featureLen = size * size;
// Directory prefix of the sample images, which are named "<number>.png".
const std::string fileDir("../samples/");
// Text file holding whitespace-separated "<sample number> <label>" pairs.
const std::string labelFileName("../label.txt");
// File the trained SVM model is saved to and loaded from.
const std::string svmDataFile("../SVM_DATA.xml");
class NumTrainData | |
{ | |
public: | |
NumTrainData() | |
{ | |
memset(data, 0, sizeof(data)); | |
result = -1; | |
} | |
public: | |
float data[size*size]; | |
int result; | |
}; | |
vector<NumTrainData> buffer; | |
int ReadTrainData(int maxCount = INT_MAX) | |
{ | |
// Open image and label file | |
ifstream lab_ifs(labelFileName); | |
if(!lab_ifs) return -1; | |
int sn, lab; | |
NumTrainData rtd; | |
int total = 0; | |
while (lab_ifs >> sn >> lab, !lab_ifs.eof()) { | |
if(++total > maxCount) break; | |
cout << total << endl; | |
if(lab_ifs.fail()) { | |
cout << "failed to read data from file." << endl; | |
lab_ifs.close(); | |
return -1; | |
} | |
if(lab < 1 || lab > 9) { | |
cout << "label is invalid. skipped." << endl; | |
continue; | |
} | |
ostringstream oss; | |
oss << fileDir << sn << ".png"; | |
Mat temp = imread(oss.str(), 0); | |
if(temp.empty()) { | |
cout << "failed to read sample #" << sn << ". skipped." << endl; | |
continue; | |
} | |
rtd.result = lab; | |
for(int i = 0; i < size; i++) | |
{ | |
for(int j = 0; j < size; j++) | |
{ | |
rtd.data[i*size+j] = temp.at<uchar>(i, j); | |
} | |
} | |
buffer.push_back(rtd); | |
} | |
lab_ifs.close(); | |
return 0; | |
} | |
/**
 * Trains a C-SVC with an RBF kernel on the given samples and saves the model
 * to svmDataFile.
 *
 * Every sample's featureLen values are passed through cv::normalize (default
 * arguments) and become one row of the training matrix; the sample's `result`
 * becomes the matching response. Nothing is trained or saved when `trainData`
 * is empty, and nothing is saved when training fails; both cases are reported
 * on stdout.
 *
 * @param trainData Samples to train on; not modified.
 */
void newSvmStudy(vector<NumTrainData>& trainData)
{
    if (trainData.empty()) {
        cout << "no training data. nothing to train." << endl;
        return;
    }

    const int sampleCount = static_cast<int>(trainData.size());
    const size_t rowBytes = featureLen * sizeof(float);

    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(sampleCount, 1, CV_32SC1);

    for (int i = 0; i < sampleCount; i++) {
        // Bind by reference: each sample carries featureLen floats.
        const NumTrainData& td = trainData[i];
        memcpy(m.data, td.data, rowBytes);
        normalize(m, m);
        // Mat::ptr yields the start of row i without hand-computed byte offsets.
        memcpy(data.ptr<float>(i), m.data, rowBytes);
        // res is CV_32SC1, so its elements must be accessed as int.
        res.at<int>(i, 0) = td.result;
    }

    /////////////START SVM TRAINING//////////////////
    CvSVM svm;
    CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    // Argument order: svm_type, kernel_type, degree, gamma, coef0, C, nu, p,
    // class_weights, term_crit.
    CvSVMParams param(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);

    if (!svm.train(data, res, Mat(), Mat(), param)) {
        cout << "failed to train SVM. model not saved." << endl;
        return;
    }
    svm.save(svmDataFile.c_str());
}
int newSvmPredict() | |
{ | |
ifstream lab_ifs(labelFileName); | |
if(!lab_ifs) return -1; | |
int sn, lab; | |
Mat m = Mat::zeros(1, featureLen, CV_32FC1); | |
CvSVM svm = CvSVM(); | |
svm.load(svmDataFile.c_str()); | |
while (lab_ifs >> sn >> lab, !lab_ifs.eof()) { | |
if(lab_ifs.fail()) { | |
cout << "failed to read data from file." << endl; | |
lab_ifs.close(); | |
return -1; | |
} | |
if(lab < 1 || lab > 9) { | |
cout << "label is invalid. skipped." << endl; | |
continue; | |
} | |
ostringstream oss; | |
oss << fileDir << sn << ".png"; | |
Mat temp = imread(oss.str(), 0); | |
if(temp.empty()) { | |
cout << "failed to read sample #" << sn << ". skipped." << endl; | |
continue; | |
} | |
for(int i = 0; i < size; i++) | |
{ | |
for(int j = 0; j < size; j++) | |
{ | |
m.at<float>(0, i*size+j) = temp.at<uchar>(i, j); | |
} | |
} | |
normalize(m, m); | |
int ret = int(svm.predict(m)); | |
stringstream ss; | |
ss << "Number " << lab << ", predict " << ret; | |
Mat img = Mat::zeros(temp.rows*30, temp.cols*30, CV_8UC3); | |
resize(temp, img, img.size()); | |
putText(img, ss.str(), Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(0, 0, 255)); | |
imshow("img", img); | |
waitKey(); | |
} | |
lab_ifs.close(); | |
return 0; | |
} | |
int main( int argc, char *argv[] ) | |
{ | |
#ifdef ON_STUDY | |
ReadTrainData(); | |
newSvmStudy(buffer); | |
#else | |
newSvmPredict(); | |
#endif | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment