// OpenCV ML / ANN_MLP test code — gist by @LaurentBerger, created December 6, 2017
#include <opencv2/opencv.hpp>
#include <iostream>
#include <map>
#include <opencv2/ts.hpp>
#include "opencv2/ml.hpp"
#include "opencv2/core/core_c.h"
using namespace std;
using namespace cv;
#define CV_ANN "ann"
enum { CV_TRAIN_ERROR = 0, CV_TEST_ERROR = 1 };
using cv::Ptr;
using cv::ml::StatModel;
using cv::ml::TrainData;
using cv::ml::NormalBayesClassifier;
using cv::ml::SVM;
using cv::ml::KNearest;
using cv::ml::ParamGrid;
using cv::ml::ANN_MLP;
using cv::ml::DTrees;
using cv::ml::Boost;
using cv::ml::RTrees;
using cv::ml::SVMSGD;
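// Base test class shared by the ML tests below: reads run parameters from a
// validation XML file, loads the corresponding CSV dataset, trains a model
// and measures its classification error on the train or test subset.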
class CV_MLBaseTest : public cvtest::BaseTest
{
public:
CV_MLBaseTest(const char* _modelName);
virtual ~CV_MLBaseTest();
protected:
virtual int read_params(CvFileStorage* fs);
virtual void run(int startFrom);
virtual int prepare_test_case(int testCaseIdx);
virtual std::string& get_validation_filename();
virtual int run_test_case(int testCaseIdx) = 0;
virtual int validate_test_results(int testCaseIdx) = 0;
int train(int testCaseIdx);
float get_test_error(int testCaseIdx, std::vector<float> *resp = 0);
void save(const char* filename);
void load(const char* filename);
Ptr<TrainData> data;
std::string modelName, validationFN;
std::vector<std::string> dataSetNames;
cv::FileStorage validationFS;
Ptr<StatModel> model;
std::map<int, int> cls_map;
int64 initSeed;
};
// ---------------------------------- MLBaseTest ---------------------------------------------------
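// Reject data layouts the ANN path does not handle: variable subsets
// (var_idx) and missing values.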
void ann_check_data(Ptr<TrainData> _data)
{
CV_TRACE_FUNCTION();
Mat values = _data->getSamples();
Mat var_idx = _data->getVarIdx();
int nvars = (int)var_idx.total();
if (nvars != 0 && nvars != values.cols)
CV_Error(CV_StsBadArg, "var_idx is not supported");
if (!_data->getMissing().empty())
CV_Error(CV_StsBadArg, "missing values are not supported");
}
int str_to_ann_train_method(String& str)
{
if (!str.compare("BACKPROP"))
return ANN_MLP::BACKPROP;
if (!str.compare("RPROP"))
return ANN_MLP::RPROP;
if (!str.compare("ANNEAL"))
return ANN_MLP::ANNEAL;
CV_Error(CV_StsBadArg, "incorrect ann train method string");
return -1;
}
int str_to_ann_activation_function(String& str)
{
if (!str.compare("IDENTITY"))
return ANN_MLP::IDENTITY;
if (!str.compare("SIGMOID_SYM"))
return ANN_MLP::SIGMOID_SYM;
if (!str.compare("GAUSSIAN"))
return ANN_MLP::GAUSSIAN;
if (!str.compare("RELU"))
return ANN_MLP::RELU;
if (!str.compare("LEAKYRELU"))
return ANN_MLP::LEAKYRELU;
CV_Error(CV_StsBadArg, "incorrect ann activation function string");
return -1;
}
// unroll the categorical responses to binary vectors
Mat ann_get_new_responses(Ptr<TrainData> _data, map<int, int>& cls_map)
{
CV_TRACE_FUNCTION();
Mat train_sidx = _data->getTrainSampleIdx();
int* train_sidx_ptr = train_sidx.ptr<int>();
Mat responses = _data->getResponses();
int cls_count = 0;
// construct cls_map
cls_map.clear();
int nresponses = (int)responses.total();
int si, n = !train_sidx.empty() ? (int)train_sidx.total() : nresponses;
for (si = 0; si < n; si++)
{
int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
int r = cvRound(responses.at<float>(sidx));
CV_DbgAssert(fabs(responses.at<float>(sidx) - r) < FLT_EPSILON);
map<int, int>::iterator it = cls_map.find(r);
if (it == cls_map.end())
cls_map[r] = cls_count++;
}
Mat new_responses = Mat::zeros(nresponses, cls_count, CV_32F);
for (si = 0; si < n; si++)
{
int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
int r = cvRound(responses.at<float>(sidx));
int cidx = cls_map[r];
new_responses.at<float>(sidx, cidx) = 1.f;
}
return new_responses;
}
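// Classification error of a trained ANN, in percent, over the train or test
// subset selected by `type`. The winning class is the argmax of the network
// output; predicted class indices are optionally returned via resp_labels.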
float ann_calc_error(Ptr<StatModel> ann, Ptr<TrainData> _data, map<int, int>& cls_map, int type, vector<float> *resp_labels)
{
CV_TRACE_FUNCTION();
float err = 0;
Mat samples = _data->getSamples();
Mat responses = _data->getResponses();
Mat sample_idx = (type == CV_TEST_ERROR) ? _data->getTestSampleIdx() : _data->getTrainSampleIdx();
int* sidx = !sample_idx.empty() ? sample_idx.ptr<int>() : 0;
ann_check_data(_data);
int sample_count = (int)sample_idx.total();
sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? samples.rows : sample_count;
float* pred_resp = 0;
vector<float> innresp;
if (sample_count > 0)
{
if (resp_labels)
{
resp_labels->resize(sample_count);
pred_resp = &((*resp_labels)[0]);
}
else
{
innresp.resize(sample_count);
pred_resp = &(innresp[0]);
}
}
int cls_count = (int)cls_map.size();
Mat output(1, cls_count, CV_32FC1);
for (int i = 0; i < sample_count; i++)
{
int si = sidx ? sidx[i] : i;
Mat sample = samples.row(si);
ann->predict(sample, output);
Point best_cls;
minMaxLoc(output, 0, 0, 0, &best_cls, 0);
int r = cvRound(responses.at<float>(si));
CV_DbgAssert(fabs(responses.at<float>(si) - r) < FLT_EPSILON);
r = cls_map[r];
int d = best_cls.x == r ? 0 : 1;
err += d;
pred_resp[i] = (float)best_cls.x;
}
err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
return err;
}
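// Seed the global RNG with one of several fixed seeds so runs are
// reproducible; the previous RNG state is saved and restored in the destructor.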
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
{
int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
CV_BIG_INT(0x0000a17166072c7c),
CV_BIG_INT(0x0201b32115cd1f9a),
CV_BIG_INT(0x0513cb37abcd1234),
CV_BIG_INT(0x0001a2b3c4d5f678)
};
int seedCount = sizeof(seeds) / sizeof(seeds[0]);
RNG& rng = theRNG();
setUseOptimized(false);
initSeed = rng.state;
rng.state = seeds[rng(seedCount)];
modelName = _modelName;
}
CV_MLBaseTest::~CV_MLBaseTest()
{
if (validationFS.isOpened())
validationFS.release();
theRNG().state = initSeed;
}
int CV_MLBaseTest::read_params(CvFileStorage* __fs)
{
CV_TRACE_FUNCTION();
FileStorage _fs(__fs, false);
if (!_fs.isOpened())
test_case_count = -1;
else
{
FileNode fn = _fs.getFirstTopLevelNode()["run_params"][modelName];
test_case_count = (int)fn.size();
if (test_case_count <= 0)
test_case_count = -1;
if (test_case_count > 0)
{
dataSetNames.resize(test_case_count);
FileNodeIterator it = fn.begin();
for (int i = 0; i < test_case_count; i++, ++it)
{
dataSetNames[i] = (string)*it;
}
}
}
return cvtest::TS::OK;
}
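// Main loop: open the validation file, read the test-case list, then run and
// validate each case, reporting the last failure code if any.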
void CV_MLBaseTest::run(int)
{
CV_TRACE_FUNCTION();
string filename = ts->get_data_path();
filename += get_validation_filename();
validationFS.open(filename, FileStorage::READ);
read_params(*validationFS);
int code = cvtest::TS::OK;
for (int i = 0; i < test_case_count; i++)
{
CV_TRACE_REGION("iteration");
int temp_code = run_test_case(i);
if (temp_code == cvtest::TS::OK)
temp_code = validate_test_results(i);
if (temp_code != cvtest::TS::OK)
code = temp_code;
}
if (test_case_count <= 0)
{
ts->printf(cvtest::TS::LOG, "validation file is not determined or not correct");
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
}
ts->set_failed_test_info(code);
}
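// Load the dataset named in the validation file as CSV. The validation XML is
// expected to look roughly like this (a sketch inferred from the lookups in
// this file; the top-level node name is arbitrary):
//   <opencv_storage><some_root>
//     <run_params><ann>"dataset_name"</ann></run_params>
//     <validation><ann><dataset_name>
//       <data_params><LS>..</LS><resp_idx>..</resp_idx><types>..</types></data_params>
//       <model_params><train_method>..</train_method><param1>..</param1><param2>..</param2></model_params>
//       <result><iter_count>..</iter_count><mean>..</mean><sigma>..</sigma></result>
//     </dataset_name></ann></validation>
//   </some_root></opencv_storage>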
int CV_MLBaseTest::prepare_test_case(int test_case_idx)
{
CV_TRACE_FUNCTION();
clear();
string dataPath = ts->get_data_path();
if (dataPath.empty())
{
ts->printf(cvtest::TS::LOG, "data path is empty");
return cvtest::TS::FAIL_INVALID_TEST_DATA;
}
string dataName = dataSetNames[test_case_idx],
filename = dataPath + dataName + ".data";
FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
CV_DbgAssert(!dataParamsNode.empty());
CV_DbgAssert(!dataParamsNode["LS"].empty());
int trainSampleCount = (int)dataParamsNode["LS"];
CV_DbgAssert(!dataParamsNode["resp_idx"].empty());
int respIdx = (int)dataParamsNode["resp_idx"];
CV_DbgAssert(!dataParamsNode["types"].empty());
String varTypes = (String)dataParamsNode["types"];
data = TrainData::loadFromCSV(filename, 0, respIdx, respIdx + 1, varTypes);
if (data.empty())
{
ts->printf(cvtest::TS::LOG, "file %s can not be read\n", filename.c_str());
return cvtest::TS::FAIL_INVALID_TEST_DATA;
}
data->setTrainTestSplit(trainSampleCount);
return cvtest::TS::OK;
}
string& CV_MLBaseTest::get_validation_filename()
{
return validationFN;
}
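// Build and train the model for one test case. For the ANN the categorical
// responses are first unrolled to one-hot vectors, then a 4-layer MLP
// (input, 100, 100, n_classes) with symmetric sigmoid activation is trained.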
int CV_MLBaseTest::train(int testCaseIdx)
{
CV_TRACE_FUNCTION();
bool is_trained = false;
FileNode modelParamsNode =
validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
if (modelName == CV_ANN)
{
String train_method_str, activation_function_str;
double param1, param2;
modelParamsNode["train_method"] >> train_method_str;
modelParamsNode["param1"] >> param1;
modelParamsNode["param2"] >> param2;
Mat new_responses = ann_get_new_responses(data, cls_map);
// binarize the responses
data = TrainData::create(data->getSamples(), data->getLayout(), new_responses,
data->getVarIdx(), data->getTrainSampleIdx());
int layer_sz[] = { data->getNAllVars(), 100, 100, (int)cls_map.size() };
Mat layer_sizes(1, (int)(sizeof(layer_sz) / sizeof(layer_sz[0])), CV_32S, layer_sz);
Ptr<ANN_MLP> m = ANN_MLP::create();
m->setLayerSizes(layer_sizes);
m->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
m->setTermCriteria(TermCriteria(TermCriteria::COUNT, 300, 0.01));
m->setTrainMethod(str_to_ann_train_method(train_method_str), param1, param2);
model = m;
}
if (!model.empty())
is_trained = model->train(data, 0);
if (!is_trained)
{
ts->printf(cvtest::TS::LOG, "in test case %d model training was failed", testCaseIdx);
return cvtest::TS::FAIL_INVALID_OUTPUT;
}
return cvtest::TS::OK;
}
float CV_MLBaseTest::get_test_error(int /*testCaseIdx*/, vector<float> *resp)
{
CV_TRACE_FUNCTION();
int type = CV_TEST_ERROR;
float err = 0;
if (modelName == CV_ANN)
err = ann_calc_error(model, data, cls_map, type, resp);
return err;
}
void CV_MLBaseTest::save(const char* filename)
{
CV_TRACE_FUNCTION();
model->save(filename);
}
void CV_MLBaseTest::load(const char* filename)
{
CV_TRACE_FUNCTION();
if (modelName == CV_ANN)
model = Algorithm::load<ANN_MLP>(filename);
else
CV_Error(CV_StsNotImplemented, "invalid stat model name");
}
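// Accuracy test: trains on a shuffled split and checks that the test error
// stays within a few sigma of the reference mean stored in the validation file.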
class CV_AMLTest : public CV_MLBaseTest
{
public:
CV_AMLTest(const char* _modelName);
virtual ~CV_AMLTest() {}
protected:
virtual int run_test_case(int testCaseIdx);
virtual int validate_test_results(int testCaseIdx);
};
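// Save-load test: trains a model, saves and reloads it, then verifies that
// the serialized files and the predicted responses are identical.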
class CV_SLMLTest : public CV_MLBaseTest
{
public:
CV_SLMLTest(const char* _modelName);
virtual ~CV_SLMLTest() {}
protected:
virtual int run_test_case(int testCaseIdx);
virtual int validate_test_results(int testCaseIdx);
std::vector<float> test_resps1, test_resps2; // predicted responses for test data
std::string fname1, fname2;
};
CV_AMLTest::CV_AMLTest(const char* _modelName) : CV_MLBaseTest(_modelName)
{
validationFN = "avalidation.xml";
}
int CV_AMLTest::run_test_case(int testCaseIdx)
{
CV_TRACE_FUNCTION();
int code = cvtest::TS::OK;
code = prepare_test_case(testCaseIdx);
if (code == cvtest::TS::OK)
{
//#define GET_STAT
#ifdef GET_STAT
const char* data_name = dataSetNames[testCaseIdx].c_str();
printf("%s, %s ", modelName.c_str(), data_name);
const int icount = 100;
float res[icount];
for (int k = 0; k < icount; k++)
{
#endif
data->shuffleTrainTest();
code = train(testCaseIdx);
#ifdef GET_STAT
float case_result = get_test_error(testCaseIdx);
res[k] = case_result;
}
float mean = 0, sigma = 0;
for (int k = 0; k < icount; k++)
{
mean += res[k];
}
mean = mean / icount;
for (int k = 0; k < icount; k++)
{
sigma += (res[k] - mean)*(res[k] - mean);
}
sigma = sqrt(sigma / icount);
printf("%f, %f\n", mean, sigma);
#endif
}
return code;
}
int CV_AMLTest::validate_test_results(int testCaseIdx)
{
CV_TRACE_FUNCTION();
int iters;
float mean, sigma;
// read validation params
FileNode resultNode =
validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["result"];
resultNode["iter_count"] >> iters;
if (iters > 0)
{
resultNode["mean"] >> mean;
resultNode["sigma"] >> sigma;
//model->save(format("/Users/vp/tmp/dtree/testcase_%02d.cur.yml", testCaseIdx)); // debug dump to a hardcoded local path
float curErr = get_test_error(testCaseIdx);
const int coeff = 4;
ts->printf(cvtest::TS::LOG, "Test case = %d; test error = %f; mean error = %f (diff=%f), %d*sigma = %f\n",
testCaseIdx, curErr, mean, abs(curErr - mean), coeff, coeff*sigma);
if (abs(curErr - mean) > coeff*sigma)
{
ts->printf(cvtest::TS::LOG, "abs(%f - %f) > %f - OUT OF RANGE!\n", curErr, mean, coeff*sigma, coeff);
return cvtest::TS::FAIL_BAD_ACCURACY;
}
else
ts->printf(cvtest::TS::LOG, ".\n");
}
else
{
ts->printf(cvtest::TS::LOG, "validation info is not suitable");
return cvtest::TS::FAIL_INVALID_TEST_DATA;
}
return cvtest::TS::OK;
}
CV_SLMLTest::CV_SLMLTest(const char* _modelName) : CV_MLBaseTest(_modelName)
{
validationFN = "slvalidation.xml";
}
int CV_SLMLTest::run_test_case(int testCaseIdx)
{
int code = cvtest::TS::OK;
code = prepare_test_case(testCaseIdx);
if (code == cvtest::TS::OK)
{
data->setTrainTestSplit(data->getNTrainSamples(), true);
code = train(testCaseIdx);
if (code == cvtest::TS::OK)
{
get_test_error(testCaseIdx, &test_resps1);
fname1 = tempfile(".json.gz");
save((fname1 + "?base64").c_str());
load(fname1.c_str());
get_test_error(testCaseIdx, &test_resps2);
fname2 = tempfile(".json.gz");
save((fname2 + "?base64").c_str());
}
else
ts->printf(cvtest::TS::LOG, "model can not be trained");
}
return code;
}
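// Validation is two-fold: (1) the two saved files must be byte-identical,
// (2) predictions made before saving and after loading must match.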
int CV_SLMLTest::validate_test_results(int testCaseIdx)
{
int code = cvtest::TS::OK;
// 1. compare files
FILE *fs1 = fopen(fname1.c_str(), "rb"), *fs2 = fopen(fname2.c_str(), "rb");
size_t sz1 = 0, sz2 = 0;
if (!fs1 || !fs2)
code = cvtest::TS::FAIL_MISSING_TEST_DATA;
if (code >= 0)
{
fseek(fs1, 0, SEEK_END); fseek(fs2, 0, SEEK_END);
sz1 = ftell(fs1);
sz2 = ftell(fs2);
fseek(fs1, 0, SEEK_SET); fseek(fs2, 0, SEEK_SET);
}
if (sz1 != sz2)
code = cvtest::TS::FAIL_INVALID_OUTPUT;
if (code >= 0)
{
const int BUFSZ = 1024;
uchar buf1[BUFSZ], buf2[BUFSZ];
for (size_t pos = 0; pos < sz1; )
{
size_t r1 = fread(buf1, 1, BUFSZ, fs1);
size_t r2 = fread(buf2, 1, BUFSZ, fs2);
if (r1 != r2 || memcmp(buf1, buf2, r1) != 0)
{
ts->printf(cvtest::TS::LOG,
"in test case %d first (%s) and second (%s) saved files differ in %d-th kb\n",
testCaseIdx, fname1.c_str(), fname2.c_str(),
(int)pos);
code = cvtest::TS::FAIL_INVALID_OUTPUT;
break;
}
pos += r1;
}
}
if (fs1)
fclose(fs1);
if (fs2)
fclose(fs2);
// delete temporary files
if (code >= 0)
{
remove(fname1.c_str());
remove(fname2.c_str());
}
if (code >= 0)
{
// 2. compare responses
CV_Assert(test_resps1.size() == test_resps2.size());
vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
for (; it1 != test_resps1.end(); ++it1, ++it2)
{
if (fabs(*it1 - *it2) > FLT_EPSILON)
{
ts->printf(cvtest::TS::LOG, "in test case %d responses predicted before saving and after loading is different", testCaseIdx);
code = cvtest::TS::FAIL_INVALID_OUTPUT;
break;
}
}
}
return code;
}
TEST(ML_ANN, save_load) { CV_SLMLTest test(CV_ANN); test.safe_run(); }
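// Loads model files written by earlier OpenCV versions and runs predict() on
// random input, checking only that no internal assertion or error is thrown.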
class CV_LegacyTest : public cvtest::BaseTest
{
public:
CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string())
: cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes)
{
}
virtual ~CV_LegacyTest() {}
protected:
void run(int)
{
unsigned int idx = 0;
for (;;)
{
if (idx >= suffixes.size())
break;
int found = (int)suffixes.find(';', idx);
string piece = suffixes.substr(idx, found - idx);
if (piece.empty())
break;
oneTest(piece);
idx += (unsigned int)piece.size() + 1;
}
}
void oneTest(const string & suffix)
{
using namespace cv::ml;
int code = cvtest::TS::OK;
string filename = ts->get_data_path() + "legacy/" + modelName + suffix;
bool isTree = false;
Ptr<StatModel> model;
if (modelName == CV_ANN)
model = Algorithm::load<ANN_MLP>(filename);
if (!model)
{
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
}
else
{
Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
ts->get_rng().fill(input, RNG::UNIFORM, 0, 40);
if (isTree)
randomFillCategories(filename, input);
Mat output;
model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0));
// just check if no internal assertions or errors thrown
}
ts->set_failed_test_info(code);
}
void randomFillCategories(const string & filename, Mat & input)
{
Mat catMap;
Mat catCount;
std::vector<uchar> varTypes;
FileStorage fs(filename, FileStorage::READ);
FileNode root = fs.getFirstTopLevelNode();
root["cat_map"] >> catMap;
root["cat_count"] >> catCount;
root["var_type"] >> varTypes;
int offset = 0;
int countOffset = 0;
uint var = 0, varCount = (uint)varTypes.size();
for (; var < varCount; ++var)
{
if (varTypes[var] == ml::VAR_CATEGORICAL)
{
int size = catCount.at<int>(0, countOffset);
for (int row = 0; row < input.rows; ++row)
{
int randomChosenIndex = offset + ((uint)ts->get_rng()) % size;
int value = catMap.at<int>(0, randomChosenIndex);
input.at<float>(row, var) = (float)value;
}
offset += size;
++countOffset;
}
}
}
string modelName;
string suffixes;
};
//TEST(ML_ANN, legacy_load) { CV_LegacyTest test(CV_ANN, "_waveform.xml"); test.safe_run(); }
/*TEST(ML_ANN, ActivationFunction)
{
String folder = string(cvtest::TS::ptr()->get_data_path());
String original_path = folder + "waveform.data";
String dataname = folder + "waveform";
Ptr<TrainData> tdata = TrainData::loadFromCSV(original_path, 0);
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
RNG& rng = theRNG();
rng.state = 1027401484159173092;
tdata->setTrainTestSplit(500);
vector<int> activationType;
activationType.push_back(ml::ANN_MLP::IDENTITY);
activationType.push_back(ml::ANN_MLP::SIGMOID_SYM);
activationType.push_back(ml::ANN_MLP::GAUSSIAN);
activationType.push_back(ml::ANN_MLP::RELU);
activationType.push_back(ml::ANN_MLP::LEAKYRELU);
vector<String> activationName;
activationName.push_back("_identity");
activationName.push_back("_sigmoid_sym");
activationName.push_back("_gaussian");
activationName.push_back("_relu");
activationName.push_back("_leakyrelu");
for (size_t i = 0; i < activationType.size(); i++)
{
Ptr<ml::ANN_MLP> x = ml::ANN_MLP::create();
Mat_<int> layerSizes(1, 4);
layerSizes(0, 0) = tdata->getNVars();
layerSizes(0, 1) = 100;
layerSizes(0, 2) = 100;
layerSizes(0, 3) = tdata->getResponses().cols;
x->setLayerSizes(layerSizes);
x->setActivationFunction(activationType[i]);
x->setTrainMethod(ml::ANN_MLP::RPROP, 0.01, 0.1);
x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 300, 0.01));
x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE);
ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << activationName[i];
#ifdef GENERATE_TESTDATA
x->save(dataname + activationName[i] + ".yml");
#else
Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + activationName[i] + ".yml");
ASSERT_TRUE(y != NULL) << "Could not load " << dataname + activationName[i] + ".yml";
Mat testSamples = tdata->getTestSamples();
Mat rx, ry, dst;
x->predict(testSamples, rx);
y->predict(testSamples, ry);
double n=cvtest::norm(rx,ry, NORM_INF);
ASSERT_TRUE(n<FLT_EPSILON) << "Predict are not equal for " << dataname + activationName[i] + ".yml and " << activationName[i];
#endif
}
}*/
//#define GENERATE_TESTDATA
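// Trains the same initial network with RPROP, ANNEAL and BACKPROP on the
// waveform data and compares the resulting weights and predictions against
// reference files (regenerate them by defining GENERATE_TESTDATA above).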
TEST(ML_ANN, Method)
{
String folder = string(cvtest::TS::ptr()->get_data_path());
String original_path = folder + "waveform.data";
String dataname = folder + "waveform";
Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
Mat responses(tdata2->getResponses().rows, 3, CV_32FC1, Scalar(0));
for (int i = 0; i<tdata2->getResponses().rows; i++)
responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
Ptr<TrainData> tdata = TrainData::create(tdata2->getSamples(), ml::ROW_SAMPLE, responses);
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
RNG& rng = theRNG();
rng.state = 1027401484159173092;
tdata->setTrainTestSplitRatio(0.01);
vector<int> methodType;
methodType.push_back(ml::ANN_MLP::RPROP);
methodType.push_back(ml::ANN_MLP::ANNEAL);
methodType.push_back(ml::ANN_MLP::BACKPROP);
vector<String> methodName;
methodName.push_back("_rprop");
methodName.push_back("_anneal");
methodName.push_back("_backprop");
#ifdef GENERATE_TESTDATA
rng.state = 1027401484159173092;
Ptr<ml::ANN_MLP> xx = ml::ANN_MLP::create();
Mat_<int> layerSizesXX(1, 3);
layerSizesXX(0, 0) = tdata->getNVars();
layerSizesXX(0, 1) = 30;
layerSizesXX(0, 2) = tdata->getResponses().cols;
xx->setLayerSizes(layerSizesXX);
xx->setActivationFunction(ml::ANN_MLP::SIGMOID_SYM);
xx->setTrainMethod(ml::ANN_MLP::RPROP);
xx->setTermCriteria(TermCriteria(TermCriteria::COUNT, 1, 0.01));
xx->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE);
FileStorage fs;
fs.open(dataname + "_init_weight.yml", FileStorage::WRITE + FileStorage::BASE64);
// xx->save(dataname + "_init_weight.yml");
xx->write(fs);
fs.release();
#endif
for (size_t i = 0; i < methodType.size(); i++)
{
rng.state = 1027401484159173092;
FileStorage fs;
fs.open(dataname + "_init_weight.yml", FileStorage::READ + FileStorage::BASE64);
Ptr<ml::ANN_MLP> x = ml::ANN_MLP::create();
x->read(fs.root());
x->setTrainMethod(methodType[i]);
if (methodType[i] == ml::ANN_MLP::BACKPROP)
{
x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 1, 0.01));
for (int jj = 0; jj<10; jj++)
{
x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
FileStorage fs;
fs.open(format("%s_%s_%db64.yml", dataname.c_str(), methodName[i].c_str(), jj), FileStorage::WRITE + FileStorage::BASE64);
x->write(fs);
fs.release();
x->save(format("%s_%s_%d.yml", dataname.c_str(), methodName[i].c_str(), jj));
}
}
else
{
x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 10, 0.01));
x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
}
ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName[i];
#ifdef GENERATE_TESTDATA
x->save(dataname + methodName[i] + ".yml");
#else
Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + methodName[i] + ".yml");
ASSERT_TRUE(y != NULL) << "Could not load " << dataname + methodName[i] + ".yml";
Mat testSamples = tdata->getTestSamples();
Mat rx, ry, dst;
for (int j = 0; j < 4; j++)
{
rx= x->getWeights(j);
ry= y->getWeights(j);
double n = cvtest::norm(rx, ry, NORM_INF);
EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i]<< " layer : "<<j;
}
x->predict(testSamples, rx);
y->predict(testSamples, ry);
double n = cvtest::norm(rx, ry, NORM_INF);
EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i];
#endif
}
}
/* End of file. */
CV_TEST_MAIN("ml")
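// CV_TEST_MAIN supplies main() for OpenCV's Google Test-based runner; point
// OPENCV_TEST_DATA_PATH at the test data directory so ts->get_data_path()
// can locate waveform.data and the validation XML files.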