#include <opencv2/opencv.hpp>
#include <iostream>
#include <map>
#include <opencv2/ts.hpp>
#include "opencv2/ml.hpp"
#include "opencv2/core/core_c.h"

using namespace std;
using namespace cv;

#define CV_ANN "ann"

enum { CV_TRAIN_ERROR = 0, CV_TEST_ERROR = 1 };

using cv::Ptr;
using cv::ml::StatModel;
using cv::ml::TrainData;
using cv::ml::NormalBayesClassifier;
using cv::ml::SVM;
using cv::ml::KNearest;
using cv::ml::ParamGrid;
using cv::ml::ANN_MLP;
using cv::ml::DTrees;
using cv::ml::Boost;
using cv::ml::RTrees;
using cv::ml::SVMSGD;

class CV_MLBaseTest : public cvtest::BaseTest
{
public:
    CV_MLBaseTest(const char* _modelName);
    virtual ~CV_MLBaseTest();
protected:
    virtual int read_params(CvFileStorage* fs);
    virtual void run(int startFrom);
    virtual int prepare_test_case(int testCaseIdx);
    virtual std::string& get_validation_filename();
    virtual int run_test_case(int testCaseIdx) = 0;
    virtual int validate_test_results(int testCaseIdx) = 0;

    int train(int testCaseIdx);
    float get_test_error(int testCaseIdx, std::vector<float> *resp = 0);
    void save(const char* filename);
    void load(const char* filename);

    Ptr<TrainData> data;
    std::string modelName, validationFN;
    std::vector<std::string> dataSetNames;
    cv::FileStorage validationFS;
    Ptr<StatModel> model;
    std::map<int, int> cls_map;
    int64 initSeed;
};
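
// Test lifecycle: run() opens the validation FileStorage, read_params() fills
// dataSetNames from the "run_params" section, and for each data set the
// subclass's run_test_case() loads and trains a model (prepare_test_case() +
// train()) while validate_test_results() compares the outcome against the
// expectations stored in the same file.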
// ---------------------------------- MLBaseTest ---------------------------------------------------
void ann_check_data(Ptr<TrainData> _data)
{
    CV_TRACE_FUNCTION();
    Mat values = _data->getSamples();
    Mat var_idx = _data->getVarIdx();
    int nvars = (int)var_idx.total();
    if (nvars != 0 && nvars != values.cols)
        CV_Error(CV_StsBadArg, "var_idx is not supported");
    if (!_data->getMissing().empty())
        CV_Error(CV_StsBadArg, "missing values are not supported");
}

int str_to_ann_train_method(String& str)
{
    if (!str.compare("BACKPROP"))
        return ANN_MLP::BACKPROP;
    if (!str.compare("RPROP"))
        return ANN_MLP::RPROP;
    if (!str.compare("ANNEAL"))
        return ANN_MLP::ANNEAL;
    CV_Error(CV_StsBadArg, "incorrect ann train method string");
    return -1;
}

int str_to_ann_activation_function(String& str)
{
    if (!str.compare("IDENTITY"))
        return ANN_MLP::IDENTITY;
    if (!str.compare("SIGMOID_SYM"))
        return ANN_MLP::SIGMOID_SYM;
    if (!str.compare("GAUSSIAN"))
        return ANN_MLP::GAUSSIAN;
    if (!str.compare("RELU"))
        return ANN_MLP::RELU;
    if (!str.compare("LEAKYRELU"))
        return ANN_MLP::LEAKYRELU;
    CV_Error(CV_StsBadArg, "incorrect ann activation function string");
    return -1;
}

// unroll the categorical responses to binary vectors
Mat ann_get_new_responses(Ptr<TrainData> _data, map<int, int>& cls_map)
{
    CV_TRACE_FUNCTION();
    Mat train_sidx = _data->getTrainSampleIdx();
    int* train_sidx_ptr = train_sidx.ptr<int>();
    Mat responses = _data->getResponses();
    int cls_count = 0;
    // construct cls_map
    cls_map.clear();
    int nresponses = (int)responses.total();
    int si, n = !train_sidx.empty() ? (int)train_sidx.total() : nresponses;
    for (si = 0; si < n; si++)
    {
        int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
        int r = cvRound(responses.at<float>(sidx));
        CV_DbgAssert(fabs(responses.at<float>(sidx) - r) < FLT_EPSILON);
        map<int, int>::iterator it = cls_map.find(r);
        if (it == cls_map.end())
            cls_map[r] = cls_count++;
    }
    Mat new_responses = Mat::zeros(nresponses, cls_count, CV_32F);
    for (si = 0; si < n; si++)
    {
        int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
        int r = cvRound(responses.at<float>(sidx));
        int cidx = cls_map[r];
        new_responses.at<float>(sidx, cidx) = 1.f;
    }
    return new_responses;
}
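
// Worked example (hypothetical values): responses = [2, 5, 2, 7] with every
// sample in the training set gives cls_map = {2->0, 5->1, 7->2} and
// new_responses =
//     [ 1 0 0 ]   // sample 0, class 2
//     [ 0 1 0 ]   // sample 1, class 5
//     [ 1 0 0 ]   // sample 2, class 2
//     [ 0 0 1 ]   // sample 3, class 7
// i.e. one one-hot row per sample, which is the layout ANN_MLP expects for a
// multi-neuron output layer.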
float ann_calc_error(Ptr<StatModel> ann, Ptr<TrainData> _data, map<int, int>& cls_map, int type, vector<float> *resp_labels)
{
    CV_TRACE_FUNCTION();
    float err = 0;
    Mat samples = _data->getSamples();
    Mat responses = _data->getResponses();
    Mat sample_idx = (type == CV_TEST_ERROR) ? _data->getTestSampleIdx() : _data->getTrainSampleIdx();
    int* sidx = !sample_idx.empty() ? sample_idx.ptr<int>() : 0;
    ann_check_data(_data);
    int sample_count = (int)sample_idx.total();
    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? samples.rows : sample_count;
    float* pred_resp = 0;
    vector<float> innresp;
    if (sample_count > 0)
    {
        if (resp_labels)
        {
            resp_labels->resize(sample_count);
            pred_resp = &((*resp_labels)[0]);
        }
        else
        {
            innresp.resize(sample_count);
            pred_resp = &(innresp[0]);
        }
    }
    int cls_count = (int)cls_map.size();
    Mat output(1, cls_count, CV_32FC1);
    for (int i = 0; i < sample_count; i++)
    {
        int si = sidx ? sidx[i] : i;
        Mat sample = samples.row(si);
        ann->predict(sample, output);
        Point best_cls;
        minMaxLoc(output, 0, 0, 0, &best_cls, 0);
        int r = cvRound(responses.at<float>(si));
        CV_DbgAssert(fabs(responses.at<float>(si) - r) < FLT_EPSILON);
        r = cls_map[r];
        int d = best_cls.x == r ? 0 : 1;
        err += d;
        pred_resp[i] = (float)best_cls.x;
    }
    err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
    return err;
}
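
// ann_calc_error() returns the misclassification rate in percent: the argmax
// of the network output (found with minMaxLoc) is the predicted class, and
// mismatches against the cls_map-remapped ground truth are counted.
// -FLT_MAX is the sentinel returned for an empty sample set.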
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
{
    int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
                      CV_BIG_INT(0x0000a17166072c7c),
                      CV_BIG_INT(0x0201b32115cd1f9a),
                      CV_BIG_INT(0x0513cb37abcd1234),
                      CV_BIG_INT(0x0001a2b3c4d5f678)
    };
    int seedCount = sizeof(seeds) / sizeof(seeds[0]);
    RNG& rng = theRNG();
    setUseOptimized(false);
    initSeed = rng.state;
    rng.state = seeds[rng(seedCount)];
    modelName = _modelName;
}

CV_MLBaseTest::~CV_MLBaseTest()
{
    if (validationFS.isOpened())
        validationFS.release();
    theRNG().state = initSeed;
}
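
// The constructor pins cv::theRNG() to one of a few fixed seeds (and turns
// off optimized code paths) so training runs are reproducible; the destructor
// restores the RNG state that was saved in initSeed.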
int CV_MLBaseTest::read_params(CvFileStorage* __fs)
{
    CV_TRACE_FUNCTION();
    FileStorage _fs(__fs, false);
    if (!_fs.isOpened())
        test_case_count = -1;
    else
    {
        FileNode fn = _fs.getFirstTopLevelNode()["run_params"][modelName];
        test_case_count = (int)fn.size();
        if (test_case_count <= 0)
            test_case_count = -1;
        if (test_case_count > 0)
        {
            dataSetNames.resize(test_case_count);
            FileNodeIterator it = fn.begin();
            for (int i = 0; i < test_case_count; i++, ++it)
            {
                dataSetNames[i] = (string)*it;
            }
        }
    }
    return cvtest::TS::OK;
}
void CV_MLBaseTest::run(int)
{
    CV_TRACE_FUNCTION();
    string filename = ts->get_data_path();
    filename += get_validation_filename();
    validationFS.open(filename, FileStorage::READ);
    read_params(*validationFS);
    int code = cvtest::TS::OK;
    for (int i = 0; i < test_case_count; i++)
    {
        CV_TRACE_REGION("iteration");
        int temp_code = run_test_case(i);
        if (temp_code == cvtest::TS::OK)
            temp_code = validate_test_results(i);
        if (temp_code != cvtest::TS::OK)
            code = temp_code;
    }
    if (test_case_count <= 0)
    {
        ts->printf(cvtest::TS::LOG, "validation file is missing or malformed");
        code = cvtest::TS::FAIL_INVALID_TEST_DATA;
    }
    ts->set_failed_test_info(code);
}
int CV_MLBaseTest::prepare_test_case(int test_case_idx)
{
    CV_TRACE_FUNCTION();
    clear();
    string dataPath = ts->get_data_path();
    if (dataPath.empty())
    {
        ts->printf(cvtest::TS::LOG, "data path is empty");
        return cvtest::TS::FAIL_INVALID_TEST_DATA;
    }
    string dataName = dataSetNames[test_case_idx],
        filename = dataPath + dataName + ".data";
    FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
    CV_DbgAssert(!dataParamsNode.empty());
    CV_DbgAssert(!dataParamsNode["LS"].empty());
    int trainSampleCount = (int)dataParamsNode["LS"];
    CV_DbgAssert(!dataParamsNode["resp_idx"].empty());
    int respIdx = (int)dataParamsNode["resp_idx"];
    CV_DbgAssert(!dataParamsNode["types"].empty());
    String varTypes = (String)dataParamsNode["types"];
    data = TrainData::loadFromCSV(filename, 0, respIdx, respIdx + 1, varTypes);
    if (data.empty())
    {
        ts->printf(cvtest::TS::LOG, "file %s cannot be read\n", filename.c_str());
        return cvtest::TS::FAIL_INVALID_TEST_DATA;
    }
    data->setTrainTestSplit(trainSampleCount);
    return cvtest::TS::OK;
}
string& CV_MLBaseTest::get_validation_filename()
{
    return validationFN;
}

int CV_MLBaseTest::train(int testCaseIdx)
{
    CV_TRACE_FUNCTION();
    bool is_trained = false;
    FileNode modelParamsNode =
        validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
    if (modelName == CV_ANN)
    {
        String train_method_str, activation_function_str;
        double param1, param2;
        modelParamsNode["train_method"] >> train_method_str;
        modelParamsNode["param1"] >> param1;
        modelParamsNode["param2"] >> param2;
        Mat new_responses = ann_get_new_responses(data, cls_map);
        // binarize the responses
        data = TrainData::create(data->getSamples(), data->getLayout(), new_responses,
                                 data->getVarIdx(), data->getTrainSampleIdx());
        int layer_sz[] = { data->getNAllVars(), 100, 100, (int)cls_map.size() };
        Mat layer_sizes(1, (int)(sizeof(layer_sz) / sizeof(layer_sz[0])), CV_32S, layer_sz);
        Ptr<ANN_MLP> m = ANN_MLP::create();
        m->setLayerSizes(layer_sizes);
        m->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
        m->setTermCriteria(TermCriteria(TermCriteria::COUNT, 300, 0.01));
        m->setTrainMethod(str_to_ann_train_method(train_method_str), param1, param2);
        model = m;
    }
    if (!model.empty())
        is_trained = model->train(data, 0);
    if (!is_trained)
    {
        ts->printf(cvtest::TS::LOG, "in test case %d model training failed", testCaseIdx);
        return cvtest::TS::FAIL_INVALID_OUTPUT;
    }
    return cvtest::TS::OK;
}
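
// Network topology used for the "ann" model: getNAllVars() inputs, two hidden
// layers of 100 neurons, one output neuron per class, symmetric sigmoid
// activation, and at most 300 training iterations; train_method, param1 and
// param2 come from the model_params node of the validation file.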
float CV_MLBaseTest::get_test_error(int /*testCaseIdx*/, vector<float> *resp)
{
    CV_TRACE_FUNCTION();
    int type = CV_TEST_ERROR;
    float err = 0;
    if (modelName == CV_ANN)
        err = ann_calc_error(model, data, cls_map, type, resp);
    return err;
}

void CV_MLBaseTest::save(const char* filename)
{
    CV_TRACE_FUNCTION();
    model->save(filename);
}

void CV_MLBaseTest::load(const char* filename)
{
    CV_TRACE_FUNCTION();
    if (modelName == CV_ANN)
        model = Algorithm::load<ANN_MLP>(filename);
    else
        CV_Error(CV_StsNotImplemented, "invalid stat model name");
}
class CV_AMLTest : public CV_MLBaseTest
{
public:
    CV_AMLTest(const char* _modelName);
    virtual ~CV_AMLTest() {}
protected:
    virtual int run_test_case(int testCaseIdx);
    virtual int validate_test_results(int testCaseIdx);
};

class CV_SLMLTest : public CV_MLBaseTest
{
public:
    CV_SLMLTest(const char* _modelName);
    virtual ~CV_SLMLTest() {}
protected:
    virtual int run_test_case(int testCaseIdx);
    virtual int validate_test_results(int testCaseIdx);

    std::vector<float> test_resps1, test_resps2; // predicted responses for test data
    std::string fname1, fname2;
};
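
// CV_AMLTest checks prediction accuracy against statistics stored in
// avalidation.xml; CV_SLMLTest checks that a model predicts identically
// before saving and after reloading, and that two consecutive saves produce
// byte-identical files.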
CV_AMLTest::CV_AMLTest(const char* _modelName) : CV_MLBaseTest(_modelName)
{
    validationFN = "avalidation.xml";
}
int CV_AMLTest::run_test_case(int testCaseIdx)
{
    CV_TRACE_FUNCTION();
    int code = cvtest::TS::OK;
    code = prepare_test_case(testCaseIdx);
    if (code == cvtest::TS::OK)
    {
//#define GET_STAT
#ifdef GET_STAT
        const char* data_name = dataSetNames[testCaseIdx].c_str();
        printf("%s, %s ", modelName.c_str(), data_name);
        const int icount = 100;
        float res[icount];
        for (int k = 0; k < icount; k++)
        {
#endif
            data->shuffleTrainTest();
            code = train(testCaseIdx);
#ifdef GET_STAT
            float case_result = get_test_error(testCaseIdx);
            res[k] = case_result;
        }
        float mean = 0, sigma = 0;
        for (int k = 0; k < icount; k++)
        {
            mean += res[k];
        }
        mean = mean / icount;
        for (int k = 0; k < icount; k++)
        {
            sigma += (res[k] - mean) * (res[k] - mean);
        }
        sigma = sqrt(sigma / icount);
        printf("%f, %f\n", mean, sigma);
#endif
    }
    return code;
}
int CV_AMLTest::validate_test_results(int testCaseIdx)
{
    CV_TRACE_FUNCTION();
    int iters;
    float mean, sigma;
    // read validation params
    FileNode resultNode =
        validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["result"];
    resultNode["iter_count"] >> iters;
    if (iters > 0)
    {
        resultNode["mean"] >> mean;
        resultNode["sigma"] >> sigma;
        // leftover debug dump to a developer-local path, disabled:
        // model->save(format("/Users/vp/tmp/dtree/testcase_%02d.cur.yml", testCaseIdx));
        float curErr = get_test_error(testCaseIdx);
        const int coeff = 4;
        ts->printf(cvtest::TS::LOG, "Test case = %d; test error = %f; mean error = %f (diff=%f), %d*sigma = %f\n",
                   testCaseIdx, curErr, mean, abs(curErr - mean), coeff, coeff * sigma);
        if (abs(curErr - mean) > coeff * sigma)
        {
            ts->printf(cvtest::TS::LOG, "abs(%f - %f) > %d*sigma = %f - OUT OF RANGE!\n", curErr, mean, coeff, coeff * sigma);
            return cvtest::TS::FAIL_BAD_ACCURACY;
        }
        else
            ts->printf(cvtest::TS::LOG, ".\n");
    }
    else
    {
        ts->printf(cvtest::TS::LOG, "validation info is not suitable");
        return cvtest::TS::FAIL_INVALID_TEST_DATA;
    }
    return cvtest::TS::OK;
}
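
// The acceptance band is mean +/- 4*sigma, where mean and sigma were measured
// offline over repeated shuffled runs (see the GET_STAT block above) and
// stored in the validation file's result node.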
CV_SLMLTest::CV_SLMLTest(const char* _modelName) : CV_MLBaseTest(_modelName)
{
    validationFN = "slvalidation.xml";
}

int CV_SLMLTest::run_test_case(int testCaseIdx)
{
    int code = cvtest::TS::OK;
    code = prepare_test_case(testCaseIdx);
    if (code == cvtest::TS::OK)
    {
        data->setTrainTestSplit(data->getNTrainSamples(), true);
        code = train(testCaseIdx);
        if (code == cvtest::TS::OK)
        {
            get_test_error(testCaseIdx, &test_resps1);
            fname1 = tempfile(".json.gz");
            save((fname1 + "?base64").c_str());
            load(fname1.c_str());
            get_test_error(testCaseIdx, &test_resps2);
            fname2 = tempfile(".json.gz");
            save((fname2 + "?base64").c_str());
        }
        else
            ts->printf(cvtest::TS::LOG, "model cannot be trained");
    }
    return code;
}
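
// Note: the ".json.gz" extension selects gzip-compressed JSON output and the
// "?base64" suffix forces raw data to be written base64-encoded; both are
// standard FileStorage features, and they make the two saved files directly
// comparable byte for byte.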
int CV_SLMLTest::validate_test_results(int testCaseIdx)
{
    int code = cvtest::TS::OK;

    // 1. compare files
    FILE *fs1 = fopen(fname1.c_str(), "rb"), *fs2 = fopen(fname2.c_str(), "rb");
    size_t sz1 = 0, sz2 = 0;
    if (!fs1 || !fs2)
        code = cvtest::TS::FAIL_MISSING_TEST_DATA;
    if (code >= 0)
    {
        fseek(fs1, 0, SEEK_END); fseek(fs2, 0, SEEK_END);
        sz1 = ftell(fs1);
        sz2 = ftell(fs2);
        fseek(fs1, 0, SEEK_SET); fseek(fs2, 0, SEEK_SET);
    }
    if (sz1 != sz2)
        code = cvtest::TS::FAIL_INVALID_OUTPUT;
    if (code >= 0)
    {
        const int BUFSZ = 1024;
        uchar buf1[BUFSZ], buf2[BUFSZ];
        for (size_t pos = 0; pos < sz1; )
        {
            size_t r1 = fread(buf1, 1, BUFSZ, fs1);
            size_t r2 = fread(buf2, 1, BUFSZ, fs2);
            if (r1 != r2 || memcmp(buf1, buf2, r1) != 0)
            {
                ts->printf(cvtest::TS::LOG,
                           "in test case %d first (%s) and second (%s) saved files differ near byte %d\n",
                           testCaseIdx, fname1.c_str(), fname2.c_str(),
                           (int)pos);
                code = cvtest::TS::FAIL_INVALID_OUTPUT;
                break;
            }
            pos += r1;
        }
    }
    if (fs1)
        fclose(fs1);
    if (fs2)
        fclose(fs2);
    // delete temporary files
    if (code >= 0)
    {
        remove(fname1.c_str());
        remove(fname2.c_str());
    }
    if (code >= 0)
    {
        // 2. compare responses
        CV_Assert(test_resps1.size() == test_resps2.size());
        vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
        for (; it1 != test_resps1.end(); ++it1, ++it2)
        {
            if (fabs(*it1 - *it2) > FLT_EPSILON)
            {
                ts->printf(cvtest::TS::LOG, "in test case %d responses predicted before saving and after loading differ", testCaseIdx);
                code = cvtest::TS::FAIL_INVALID_OUTPUT;
                break;
            }
        }
    }
    return code;
}

TEST(ML_ANN, save_load) { CV_SLMLTest test(CV_ANN); test.safe_run(); }
class CV_LegacyTest : public cvtest::BaseTest
{
public:
    CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string())
        : cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes)
    {
    }
    virtual ~CV_LegacyTest() {}
protected:
    void run(int)
    {
        unsigned int idx = 0;
        for (;;)
        {
            if (idx >= suffixes.size())
                break;
            // find() returns npos when no ';' is left; the cast makes substr()
            // take the rest of the string, so the last suffix is still handled
            int found = (int)suffixes.find(';', idx);
            string piece = suffixes.substr(idx, found - idx);
            if (piece.empty())
                break;
            oneTest(piece);
            idx += (unsigned int)piece.size() + 1;
        }
    }
    void oneTest(const string & suffix)
    {
        using namespace cv::ml;
        int code = cvtest::TS::OK;
        string filename = ts->get_data_path() + "legacy/" + modelName + suffix;
        bool isTree = false;
        Ptr<StatModel> model;
        if (modelName == CV_ANN)
            model = Algorithm::load<ANN_MLP>(filename);
        if (!model)
        {
            code = cvtest::TS::FAIL_INVALID_TEST_DATA;
        }
        else
        {
            Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
            ts->get_rng().fill(input, RNG::UNIFORM, 0, 40);
            if (isTree)
                randomFillCategories(filename, input);
            Mat output;
            model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0));
            // just check that no internal assertions or errors are thrown
        }
        ts->set_failed_test_info(code);
    }
    void randomFillCategories(const string & filename, Mat & input)
    {
        Mat catMap;
        Mat catCount;
        std::vector<uchar> varTypes;
        FileStorage fs(filename, FileStorage::READ);
        FileNode root = fs.getFirstTopLevelNode();
        root["cat_map"] >> catMap;
        root["cat_count"] >> catCount;
        root["var_type"] >> varTypes;
        int offset = 0;
        int countOffset = 0;
        uint var = 0, varCount = (uint)varTypes.size();
        for (; var < varCount; ++var)
        {
            if (varTypes[var] == ml::VAR_CATEGORICAL)
            {
                int size = catCount.at<int>(0, countOffset);
                for (int row = 0; row < input.rows; ++row)
                {
                    int randomChosenIndex = offset + ((uint)ts->get_rng()) % size;
                    int value = catMap.at<int>(0, randomChosenIndex);
                    input.at<float>(row, var) = (float)value;
                }
                offset += size;
                ++countOffset;
            }
        }
    }
    string modelName;
    string suffixes;
};
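
// The legacy test loads pre-trained model files shipped with the test data
// ("legacy/<model><suffix>") and runs predict() on random inputs; it only
// verifies that loading and prediction complete without throwing.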
//TEST(ML_ANN, legacy_load) { CV_LegacyTest test(CV_ANN, "_waveform.xml"); test.safe_run(); }

/*TEST(ML_ANN, ActivationFunction)
{
    String folder = string(cvtest::TS::ptr()->get_data_path());
    String original_path = folder + "waveform.data";
    String dataname = folder + "waveform";
    Ptr<TrainData> tdata = TrainData::loadFromCSV(original_path, 0);
    ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
    RNG& rng = theRNG();
    rng.state = 1027401484159173092;
    tdata->setTrainTestSplit(500);
    vector<int> activationType;
    activationType.push_back(ml::ANN_MLP::IDENTITY);
    activationType.push_back(ml::ANN_MLP::SIGMOID_SYM);
    activationType.push_back(ml::ANN_MLP::GAUSSIAN);
    activationType.push_back(ml::ANN_MLP::RELU);
    activationType.push_back(ml::ANN_MLP::LEAKYRELU);
    vector<String> activationName;
    activationName.push_back("_identity");
    activationName.push_back("_sigmoid_sym");
    activationName.push_back("_gaussian");
    activationName.push_back("_relu");
    activationName.push_back("_leakyrelu");
    for (size_t i = 0; i < activationType.size(); i++)
    {
        Ptr<ml::ANN_MLP> x = ml::ANN_MLP::create();
        Mat_<int> layerSizes(1, 4);
        layerSizes(0, 0) = tdata->getNVars();
        layerSizes(0, 1) = 100;
        layerSizes(0, 2) = 100;
        layerSizes(0, 3) = tdata->getResponses().cols;
        x->setLayerSizes(layerSizes);
        x->setActivationFunction(activationType[i]);
        x->setTrainMethod(ml::ANN_MLP::RPROP, 0.01, 0.1);
        x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 300, 0.01));
        x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE);
        ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << activationName[i];
#ifdef GENERATE_TESTDATA
        x->save(dataname + activationName[i] + ".yml");
#else
        Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + activationName[i] + ".yml");
        ASSERT_TRUE(y != NULL) << "Could not load " << dataname + activationName[i] + ".yml";
        Mat testSamples = tdata->getTestSamples();
        Mat rx, ry, dst;
        x->predict(testSamples, rx);
        y->predict(testSamples, ry);
        double n = cvtest::norm(rx, ry, NORM_INF);
        ASSERT_TRUE(n < FLT_EPSILON) << "Predict are not equal for " << dataname + activationName[i] + ".yml and " << activationName[i];
#endif
    }
}*/
//#define GENERATE_TESTDATA
TEST(ML_ANN, Method)
{
    String folder = string(cvtest::TS::ptr()->get_data_path());
    String original_path = folder + "waveform.data";
    String dataname = folder + "waveform";
    Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
    Mat responses(tdata2->getResponses().rows, 3, CV_32FC1, Scalar(0));
    for (int i = 0; i < tdata2->getResponses().rows; i++)
        responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
    Ptr<TrainData> tdata = TrainData::create(tdata2->getSamples(), ml::ROW_SAMPLE, responses);
    ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
    RNG& rng = theRNG();
    rng.state = 1027401484159173092;
    tdata->setTrainTestSplitRatio(0.01);
    vector<int> methodType;
    methodType.push_back(ml::ANN_MLP::RPROP);
    methodType.push_back(ml::ANN_MLP::ANNEAL);
    methodType.push_back(ml::ANN_MLP::BACKPROP);
    vector<String> methodName;
    methodName.push_back("_rprop");
    methodName.push_back("_anneal");
    methodName.push_back("_backprop");
#ifdef GENERATE_TESTDATA
    rng.state = 1027401484159173092;
    Ptr<ml::ANN_MLP> xx = ml::ANN_MLP::create();
    Mat_<int> layerSizesXX(1, 3);
    layerSizesXX(0, 0) = tdata->getNVars();
    layerSizesXX(0, 1) = 30;
    layerSizesXX(0, 2) = tdata->getResponses().cols;
    xx->setLayerSizes(layerSizesXX);
    xx->setActivationFunction(ml::ANN_MLP::SIGMOID_SYM);
    xx->setTrainMethod(ml::ANN_MLP::RPROP);
    xx->setTermCriteria(TermCriteria(TermCriteria::COUNT, 1, 0.01));
    xx->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE);
    FileStorage fs;
    fs.open(dataname + "_init_weight.yml", FileStorage::WRITE + FileStorage::BASE64);
    // xx->save(dataname + "_init_weight.yml");
    xx->write(fs);
    fs.release();
#endif
    for (size_t i = 0; i < methodType.size(); i++)
    {
        rng.state = 1027401484159173092;
        FileStorage fs;
        fs.open(dataname + "_init_weight.yml", FileStorage::READ + FileStorage::BASE64);
        Ptr<ml::ANN_MLP> x = ml::ANN_MLP::create();
        x->read(fs.root());
        x->setTrainMethod(methodType[i]);
        if (methodType[i] == ml::ANN_MLP::BACKPROP)
        {
            x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 1, 0.01));
            for (int jj = 0; jj < 10; jj++)
            {
                x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
                FileStorage fs2;  // a separate handle, so the read-mode fs above is not shadowed
                fs2.open(format("%s_%s_%db64.yml", dataname.c_str(), methodName[i].c_str(), jj), FileStorage::WRITE + FileStorage::BASE64);
                x->write(fs2);
                fs2.release();
                x->save(format("%s_%s_%d.yml", dataname.c_str(), methodName[i].c_str(), jj));
            }
        }
        else
        {
            x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 10, 0.01));
            x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
        }
        ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName[i];
#ifdef GENERATE_TESTDATA
        x->save(dataname + methodName[i] + ".yml");
#else
        Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + methodName[i] + ".yml");
        ASSERT_TRUE(y != NULL) << "Could not load " << dataname + methodName[i] + ".yml";
        Mat testSamples = tdata->getTestSamples();
        Mat rx, ry, dst;
        for (int j = 0; j < 4; j++)
        {
            rx = x->getWeights(j);
            ry = y->getWeights(j);
            double n = cvtest::norm(rx, ry, NORM_INF);
            EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i] << " layer : " << j;
        }
        x->predict(testSamples, rx);
        y->predict(testSamples, ry);
        double n = cvtest::norm(rx, ry, NORM_INF);
        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i];
#endif
    }
}
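
// With GENERATE_TESTDATA defined, this test writes fresh golden files
// (waveform_init_weight.yml and waveform_<method>.yml) instead of comparing;
// without it, the trained weights and the predictions must match the stored
// models to within FLT_EPSILON.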
/* End of file. */
CV_TEST_MAIN("ml")