Created
August 6, 2017 11:28
-
-
Save johnhany/a48487dcacdb4c2108e919b421d19cfb to your computer and use it in GitHub Desktop.
KNN on MNIST with OpenCV
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <fstream>
#include <iostream>
#include <limits>
#include <string>

#include <opencv2/opencv.hpp>

using namespace std;
using namespace cv;
#define PIXELS_IN_IMAGE 28*28 | |
#define ENABLE_TRAIN 1 | |
int reverseInt(int i); | |
int loadMNIST(const string pic_filename, const string label_filename, Mat& training_data, Mat& label_data); | |
int main(int argc, char* argv[]) | |
{ | |
try | |
{ | |
Mat train_data_mat, train_label_mat; | |
Mat test_data_mat, test_label_mat; | |
loadMNIST("F:\\Datasets\\mnist\\train-images.idx3-ubyte", "F:\\Datasets\\mnist\\train-labels.idx1-ubyte", train_data_mat, train_label_mat); | |
loadMNIST("F:\\Datasets\\mnist\\t10k-images.idx3-ubyte", "F:\\Datasets\\mnist\\t10k-labels.idx1-ubyte", test_data_mat, test_label_mat); | |
train_data_mat.convertTo(train_data_mat, CV_32FC1); | |
train_label_mat.convertTo(train_label_mat, CV_32SC1); | |
test_data_mat.convertTo(test_data_mat, CV_32FC1); | |
Ptr<ml::KNearest> knn; | |
knn = ml::KNearest::create(); | |
knn->setDefaultK(10); | |
knn->train(train_data_mat, ml::SampleTypes::ROW_SAMPLE, train_label_mat); | |
int correct_count = 0; | |
for (int idx = 0; idx < test_label_mat.rows; idx++) { | |
Mat result_mat; | |
float response = knn->findNearest(test_data_mat.row(idx), knn->getDefaultK(), result_mat); | |
if(test_label_mat.at<uchar>(idx, 0) == (uchar)response) { | |
correct_count++; | |
} | |
} | |
double correct_ratio = (double)correct_count / (double)test_label_mat.rows; | |
cout << correct_ratio << endl; | |
} | |
catch (const Exception& ex) | |
{ | |
cout << "Error: " << ex.what() << endl; | |
} | |
cin.get(); | |
return 0; | |
} | |
/**
 * Reverses the byte order of a 32-bit integer (byte 0 <-> byte 3,
 * byte 1 <-> byte 2).
 *
 * @param i Value whose bytes are to be swapped.
 * @return  The value with its four bytes in reversed order.
 */
int reverseInt(int i) {
    // Operate on the unsigned bit pattern so that no shift moves a bit into
    // (or out of) the sign bit of a signed int.
    const unsigned int bits = static_cast<unsigned int>(i);
    const unsigned int swapped = ((bits & 0x000000FFu) << 24) |
                                 ((bits & 0x0000FF00u) << 8) |
                                 ((bits & 0x00FF0000u) >> 8) |
                                 ((bits & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}
int loadMNIST(const string pic_filename, const string label_filename, Mat& training_data, Mat& label_data) { | |
std::ifstream pic_file(pic_filename, std::ios::binary); | |
std::ifstream label_file(label_filename, std::ios::binary); | |
if (pic_file.is_open() && label_file.is_open()) { | |
int magic_number = 0; | |
int number_of_images = 0; | |
int n_rows = 0; | |
int n_cols = 0; | |
label_file.read((char*)&magic_number, sizeof(magic_number)); | |
pic_file.read((char*)&magic_number, sizeof(magic_number)); | |
magic_number = reverseInt(magic_number); | |
label_file.read((char*)&number_of_images, sizeof(number_of_images)); | |
pic_file.read((char*)&number_of_images, sizeof(number_of_images)); | |
number_of_images = reverseInt(number_of_images); | |
pic_file.read((char*)&n_rows, sizeof(n_rows)); | |
n_rows = reverseInt(n_rows); | |
pic_file.read((char*)&n_cols, sizeof(n_cols)); | |
n_cols = reverseInt(n_cols); | |
int n_stride = n_cols * n_rows; | |
training_data = Mat(number_of_images, n_stride, CV_8U); | |
label_data = Mat(number_of_images, 1, CV_8U); | |
for (int i = 0; i < number_of_images; ++i) { | |
// for (int i = 0; i < 5000; ++i) { | |
unsigned char data_tmp[PIXELS_IN_IMAGE]; | |
pic_file.read((char*)data_tmp, sizeof(unsigned char) * n_stride); | |
Mat row_image(1, n_stride, CV_8U, data_tmp); | |
row_image.row(0).copyTo(training_data.row(i)); | |
char label = 0; | |
label_file.read((char*)&label, sizeof(label)); | |
label_data.at<uchar>(i, 0) = label; | |
} | |
} else { | |
return 1; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment