Last active
November 28, 2017 02:48
-
-
Save ruofeidu/7ed4652fcc1f41b0eb08d85ab2405dc1 to your computer and use it in GitHub Desktop.
MNIST C++ IO Helper Header
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <cmath>
#include <cstdint>
#include <fstream>
#include <limits>
#include <stdexcept>
#include <string>
#include <opencv/cv.hpp>
using namespace cv;
using namespace std;
/// Reads an MNIST IDX3 image file (e.g. `train-images-idx3-ubyte`).
///
/// Expected layout: four 32-bit big-endian header fields -- magic number
/// (2051), image count, row count, column count -- followed by
/// count * rows * cols unsigned pixel bytes, one image after another.
///
/// @param full_path        path of the IDX3 file to read.
/// @param number_of_images [out] number of images in the file; assigned only
///                         on success.
/// @param image_size       [out] bytes per image (rows * cols); assigned only
///                         on success.
/// @return a new[]-allocated array of `number_of_images` pointers, each
///         pointing to a new[]-allocated buffer of `image_size` bytes. The
///         caller owns both levels: `delete[]` every row, then the array.
/// @throws std::runtime_error if the file cannot be opened, the magic number
///         is wrong, the declared dimensions do not fit in an `int`, or the
///         file ends before all declared data has been read.
unsigned char** read_mnist_images(std::string full_path, int& number_of_images, int& image_size) {
    std::ifstream file(full_path, std::ios::binary);
    if (!file.is_open()) {
        throw std::runtime_error("Cannot open file `" + full_path + "`!");
    }

    // IDX integers are big-endian. Assembling them byte by byte works on any
    // host and avoids the signed-overflow UB of shifting a byte >= 0x80 into
    // the sign bit of an `int`.
    auto read_be32 = [&file, &full_path]() -> std::uint32_t {
        unsigned char b[4];
        if (!file.read(reinterpret_cast<char*>(b), sizeof b)) {
            throw std::runtime_error("Truncated MNIST header in `" + full_path + "`!");
        }
        return (std::uint32_t{b[0]} << 24) | (std::uint32_t{b[1]} << 16) |
               (std::uint32_t{b[2]} << 8) | std::uint32_t{b[3]};
    };

    if (read_be32() != 2051) throw std::runtime_error("Invalid MNIST image file!");
    const std::uint32_t count = read_be32();
    const std::uint32_t rows = read_be32();
    const std::uint32_t cols = read_be32();

    // Multiply in 64 bits so a corrupt header cannot overflow the product.
    const std::uint64_t pixels = std::uint64_t{rows} * cols;
    constexpr auto kIntMax = static_cast<std::uint64_t>(std::numeric_limits<int>::max());
    if (count > kIntMax || pixels > kIntMax) {
        throw std::runtime_error("MNIST image dimensions too large in `" + full_path + "`!");
    }
    const int n = static_cast<int>(count);
    const int size = static_cast<int>(pixels);

    // Value-initialised (all nullptr) so the cleanup below may delete[] every
    // slot regardless of how far the loop got.
    unsigned char** dataset = new unsigned char*[n]();
    try {
        for (int i = 0; i < n; ++i) {
            dataset[i] = new unsigned char[size];
            if (!file.read(reinterpret_cast<char*>(dataset[i]), size)) {
                throw std::runtime_error("Truncated MNIST image data in `" + full_path + "`!");
            }
        }
    } catch (...) {
        for (int i = 0; i < n; ++i) delete[] dataset[i];
        delete[] dataset;
        throw;
    }

    number_of_images = n;
    image_size = size;
    return dataset;
}
/// Writes images as raw bytes with no header: image 0's `image_size` bytes,
/// then image 1's, and so on. read_mnist_perlin_images reads this layout back.
///
/// @param full_path        destination file; created or truncated.
/// @param _dataset         array of `number_of_images` pointers, each to at
///                         least `image_size` readable bytes.
/// @param number_of_images number of images to write (must be >= 0).
/// @param image_size       bytes per image (must be >= 0).
/// @return true if every byte was written and the file closed cleanly,
///         false if a write or the final flush failed.
/// @throws std::invalid_argument on negative sizes, or a null `_dataset`
///         when at least one image is requested (checked before the file
///         is opened, so an existing file is not truncated).
/// @throws std::runtime_error if the file cannot be opened.
bool write_mnist_perlin_images(std::string full_path, unsigned char** _dataset,
                               const int& number_of_images, const int& image_size) {
    if (number_of_images < 0 || image_size < 0) {
        throw std::invalid_argument("Image count and size must be non-negative");
    }
    if (number_of_images > 0 && _dataset == nullptr) {
        throw std::invalid_argument("Null image dataset");
    }

    std::ofstream file(full_path, std::ios::binary);
    if (!file.is_open()) {
        throw std::runtime_error("Cannot open file `" + full_path + "`!");
    }
    for (int i = 0; i < number_of_images; i++) {
        file.write(reinterpret_cast<const char*>(_dataset[i]), image_size);
    }
    // Close explicitly: the destructor would silently swallow a failed final
    // flush, and the boolean result should reflect it.
    file.close();
    return !file.fail();
}
/// Reads headerless raw image bytes in the layout produced by
/// write_mnist_perlin_images: `number_of_images` consecutive images of
/// `image_size` bytes each.
///
/// @param _full_path       path of the raw image file.
/// @param number_of_images images to read (must be >= 0); defaults to the
///                         previously hard-coded 10000.
/// @param image_size       bytes per image (must be >= 0); defaults to the
///                         previously hard-coded 28 * 28.
/// @return a new[]-allocated array of `number_of_images` pointers, each to a
///         new[]-allocated buffer of `image_size` bytes. The caller owns both
///         levels: `delete[]` every row, then the array.
/// @throws std::invalid_argument on negative count or size.
/// @throws std::runtime_error if the file cannot be opened or holds fewer
///         bytes than requested (nothing is leaked in that case).
unsigned char** read_mnist_perlin_images(std::string _full_path, int number_of_images = 10000,
                                         int image_size = 28 * 28) {
    if (number_of_images < 0 || image_size < 0) {
        throw std::invalid_argument("Image count and size must be non-negative");
    }
    std::ifstream file(_full_path, std::ios::binary);
    if (!file.is_open()) {
        throw std::runtime_error("Cannot open file `" + _full_path + "`!");
    }

    // Value-initialised (all nullptr) so cleanup can delete[] every slot.
    unsigned char** dataset = new unsigned char*[number_of_images]();
    try {
        for (int i = 0; i < number_of_images; i++) {
            dataset[i] = new unsigned char[image_size];
            if (!file.read(reinterpret_cast<char*>(dataset[i]), image_size)) {
                throw std::runtime_error("File `" + _full_path + "` is too short for the requested images!");
            }
        }
    } catch (...) {
        for (int i = 0; i < number_of_images; i++) delete[] dataset[i];
        delete[] dataset;
        throw;
    }
    return dataset;
}
| uchar convertImage(uchar** imgs, uchar* labels, int id, int image_size, Mat &img) { | |
| int w = int(round(sqrt(image_size))); | |
| img = Mat(w, w, CV_8UC1, imgs[id]); | |
| return labels[id]; | |
| } | |
/// Reads an MNIST IDX1 label file (e.g. `train-labels-idx1-ubyte`).
///
/// Expected layout: two 32-bit big-endian header fields -- magic number
/// (2049) and label count -- followed by one unsigned byte per label.
///
/// @param full_path        path of the IDX1 file to read.
/// @param number_of_labels [out] number of labels; assigned only on success.
/// @return a new[]-allocated array of `number_of_labels` bytes; the caller
///         must `delete[]` it.
/// @throws std::runtime_error if the file cannot be opened, the magic number
///         is wrong, the count does not fit in an `int`, or the file ends
///         before all declared labels have been read.
unsigned char* read_mnist_labels(std::string full_path, int& number_of_labels) {
    std::ifstream file(full_path, std::ios::binary);
    if (!file.is_open()) {
        throw std::runtime_error("Unable to open file `" + full_path + "`!");
    }

    // IDX integers are big-endian. Assembling them byte by byte works on any
    // host and avoids the signed-overflow UB of shifting a byte >= 0x80 into
    // the sign bit of an `int`.
    auto read_be32 = [&file, &full_path]() -> std::uint32_t {
        unsigned char b[4];
        if (!file.read(reinterpret_cast<char*>(b), sizeof b)) {
            throw std::runtime_error("Truncated MNIST header in `" + full_path + "`!");
        }
        return (std::uint32_t{b[0]} << 24) | (std::uint32_t{b[1]} << 16) |
               (std::uint32_t{b[2]} << 8) | std::uint32_t{b[3]};
    };

    if (read_be32() != 2049) throw std::runtime_error("Invalid MNIST label file!");
    const std::uint32_t count = read_be32();
    if (count > static_cast<std::uint32_t>(std::numeric_limits<int>::max())) {
        throw std::runtime_error("MNIST label count too large in `" + full_path + "`!");
    }
    const int n = static_cast<int>(count);

    unsigned char* dataset = new unsigned char[n];
    // Labels are one byte each, so the whole payload is read in one call.
    if (!file.read(reinterpret_cast<char*>(dataset), n)) {
        delete[] dataset;
        throw std::runtime_error("Truncated MNIST label data in `" + full_path + "`!");
    }
    number_of_labels = n;
    return dataset;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment