Created
January 14, 2017 15:16
-
-
Save daviddoria/1725d6dfe1bf6ee2f951eb6ae9b6a973 to your computer and use it in GitHub Desktop.
Demonstrate all "one-to-rest" pairs SVM confidence
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <map>
#include <set>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>
// Side lengths, in pixels, of the canvas that showRegions() paints and that
// createTrainingData() uses as the upper bounds / scale of its random points.
constexpr int WIDTH = 512;
constexpr int HEIGHT = 512;

// The four categories used by the demo. The enumerators serve as keys of the
// colour and ID lookup tables and of the per-class SVM map.
enum class Classes { Class1, Class2, Class3, Class4 };
std::map<Classes, cv::Vec3b> colorsv{ {Classes::Class1, cv::Vec3b(255, 0, 0)}, | |
{Classes::Class2, cv::Vec3b(0, 255, 0)}, | |
{Classes::Class3, cv::Vec3b(0, 0, 255)}, | |
{Classes::Class4, cv::Vec3b(0, 255, 255)} }; | |
std::map<Classes, cv::Vec3b> colorsv_shaded{ {Classes::Class1, cv::Vec3b(200, 0, 0)}, | |
{Classes::Class2, cv::Vec3b(0, 200, 0)}, | |
{Classes::Class3, cv::Vec3b(0, 0, 200)}, | |
{Classes::Class4, cv::Vec3b(0, 200, 200)} }; | |
std::map<Classes, int> classIDs { {Classes::Class1, 1}, | |
{Classes::Class2, 2}, | |
{Classes::Class3, 3}, | |
{Classes::Class4, 4}}; | |
std::map<int, Classes> classTypes { {1, Classes::Class1}, | |
{2, Classes::Class2}, | |
{3, Classes::Class3}, | |
{4, Classes::Class4}}; | |
// Returns a copy of the (key, value) entry of `x` that holds the largest value.
//
// When several entries tie for the largest value, the one with the smallest
// key is returned: std::max_element keeps the first maximum it encounters and
// a std::map is traversed in ascending key order.
//
// Throws std::invalid_argument if `x` is empty, because there is no maximum
// to return (dereferencing the end iterator would be undefined behaviour).
template<typename KeyType, typename ValueType>
std::pair<KeyType,ValueType> get_max( const std::map<KeyType,ValueType>& x ) {
    if (x.empty()) {
        throw std::invalid_argument("get_max: map is empty");
    }
    // Compare through the map's real element type, pair<const Key, Value>.
    // Spelling the parameters as pair<Key, Value> would force an implicit
    // converting copy of both operands on every comparison.
    using EntryType = typename std::map<KeyType,ValueType>::value_type;
    return *std::max_element(x.begin(), x.end(), [] (const EntryType& lhs, const EntryType& rhs) {
        return lhs.second < rhs.second;
    });
}
// Gathers every distinct element of `mat` into an ordered set.
std::set<int> uniqueValues(const cv::Mat1i& mat)
{
    // The matrix iterators visit each element once; the set's range
    // constructor discards the duplicates.
    return std::set<int>(mat.begin(), mat.end());
}
// Paints a HEIGHT x WIDTH picture of the classifiers' decisions and shows it
// in a window titled "Regions", then waits for a key press.
//
// For every pixel, the point (x = column, y = row) is passed to each SVM in
// `svms` with the RAW_OUTPUT flag, and the class whose SVM returns the
// largest response claims the pixel (ties go to the smallest class key, see
// get_max). The pixel takes that class's colour from `colorsv` when the
// winning response is >= 0 and from `colorsv_shaded` otherwise.
//
// Does nothing when `svms` is empty: no class could claim any pixel.
void showRegions(const std::map<Classes, cv::Ptr<cv::ml::SVM>>& svms)
{
    if (svms.empty()) {
        return;
    }

    cv::Mat3b regions(HEIGHT, WIDTH);

    // A single 1x2 query row, overwritten for each pixel, rather than a fresh
    // allocation per pixel.
    cv::Mat1f sample(1, 2);

    for (int r = 0; r < HEIGHT; ++r)
    {
        for (int c = 0; c < WIDTH; ++c)
        {
            sample(0, 0) = static_cast<float>(c);
            sample(0, 1) = static_cast<float>(r);

            std::map<Classes, float> responses;
            for (const auto& entry : svms) {
                responses[entry.first] = entry.second->predict(sample, cv::noArray(), cv::ml::StatModel::RAW_OUTPUT);
            }

            const std::pair<Classes, float> best = get_max(responses);

            // .at() reads the colour tables without ever inserting into them.
            regions(r, c) = (best.second >= 0) ? colorsv.at(best.first)
                                               : colorsv_shaded.at(best.first);
        }
    }

    cv::imshow("Regions", regions);
    cv::waitKey();
}
std::pair<cv::Mat1f, cv::Mat1i> createTrainingData() | |
{ | |
const int N_SAMPLES_PER_CLASS = 10; | |
const float NON_LINEAR_SAMPLES_RATIO = 0.1; | |
int N_NON_LINEAR_SAMPLES = N_SAMPLES_PER_CLASS * NON_LINEAR_SAMPLES_RATIO; | |
int N_LINEAR_SAMPLES = N_SAMPLES_PER_CLASS - N_NON_LINEAR_SAMPLES; | |
cv::Mat1f data(4 * N_SAMPLES_PER_CLASS, 2); | |
cv::Mat1i labels(4 * N_SAMPLES_PER_CLASS, 1); | |
cv::RNG rng(0); | |
//////////////////////// | |
// Set training data | |
//////////////////////// | |
// Class 1 | |
cv::Mat1f class1 = data.rowRange(0, 0.5 * N_LINEAR_SAMPLES); | |
cv::Mat1f x1 = class1.colRange(0, 1); | |
cv::Mat1f y1 = class1.colRange(1, 2); | |
rng.fill(x1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH)); | |
rng.fill(y1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT / 8)); | |
class1 = data.rowRange(0.5 * N_LINEAR_SAMPLES, 1 * N_LINEAR_SAMPLES); | |
x1 = class1.colRange(0, 1); | |
y1 = class1.colRange(1, 2); | |
rng.fill(x1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH)); | |
rng.fill(y1, cv::RNG::UNIFORM, cv::Scalar(7*HEIGHT / 8), cv::Scalar(HEIGHT)); | |
class1 = data.rowRange(N_LINEAR_SAMPLES, 1 * N_SAMPLES_PER_CLASS); | |
x1 = class1.colRange(0, 1); | |
y1 = class1.colRange(1, 2); | |
rng.fill(x1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH)); | |
rng.fill(y1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT)); | |
// Class 2 | |
cv::Mat1f class2 = data.rowRange(N_SAMPLES_PER_CLASS, N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES); | |
cv::Mat1f x2 = class2.colRange(0, 1); | |
cv::Mat1f y2 = class2.colRange(1, 2); | |
rng.fill(x2, cv::RNG::NORMAL, cv::Scalar(3 * WIDTH / 4), cv::Scalar(WIDTH/16)); | |
rng.fill(y2, cv::RNG::NORMAL, cv::Scalar(HEIGHT / 2), cv::Scalar(HEIGHT/4)); | |
class2 = data.rowRange(N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES, 2 * N_SAMPLES_PER_CLASS); | |
x2 = class2.colRange(0, 1); | |
y2 = class2.colRange(1, 2); | |
rng.fill(x2, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH)); | |
rng.fill(y2, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT)); | |
// Class 3 | |
cv::Mat1f class3 = data.rowRange(2 * N_SAMPLES_PER_CLASS, 2 * N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES); | |
cv::Mat1f x3 = class3.colRange(0, 1); | |
cv::Mat1f y3 = class3.colRange(1, 2); | |
rng.fill(x3, cv::RNG::NORMAL, cv::Scalar(WIDTH / 4), cv::Scalar(WIDTH/8)); | |
rng.fill(y3, cv::RNG::NORMAL, cv::Scalar(HEIGHT / 2), cv::Scalar(HEIGHT/8)); | |
class3 = data.rowRange(2*N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES, 3 * N_SAMPLES_PER_CLASS); | |
x3 = class3.colRange(0, 1); | |
y3 = class3.colRange(1, 2); | |
rng.fill(x3, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH)); | |
rng.fill(y3, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT)); | |
// Class 4 | |
cv::Mat1f class4 = data.rowRange(3 * N_SAMPLES_PER_CLASS, 3 * N_SAMPLES_PER_CLASS + 0.5 * N_LINEAR_SAMPLES); | |
cv::Mat1f x4 = class4.colRange(0, 1); | |
cv::Mat1f y4 = class4.colRange(1, 2); | |
rng.fill(x4, cv::RNG::NORMAL, cv::Scalar(WIDTH / 2), cv::Scalar(WIDTH / 16)); | |
rng.fill(y4, cv::RNG::NORMAL, cv::Scalar(HEIGHT / 4), cv::Scalar(HEIGHT / 16)); | |
class4 = data.rowRange(3 * N_SAMPLES_PER_CLASS + 0.5 * N_LINEAR_SAMPLES, 3 * N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES); | |
x4 = class4.colRange(0, 1); | |
y4 = class4.colRange(1, 2); | |
rng.fill(x4, cv::RNG::NORMAL, cv::Scalar(WIDTH / 2), cv::Scalar(WIDTH / 16)); | |
rng.fill(y4, cv::RNG::NORMAL, cv::Scalar(3 * HEIGHT / 4), cv::Scalar(HEIGHT / 16)); | |
class4 = data.rowRange(3 * N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES, 4 * N_SAMPLES_PER_CLASS); | |
x4 = class4.colRange(0, 1); | |
y4 = class4.colRange(1, 2); | |
rng.fill(x4, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH)); | |
rng.fill(y4, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT)); | |
// Labels | |
labels.rowRange(0*N_SAMPLES_PER_CLASS, 1*N_SAMPLES_PER_CLASS).setTo(classIDs[Classes::Class1]); | |
labels.rowRange(1*N_SAMPLES_PER_CLASS, 2*N_SAMPLES_PER_CLASS).setTo(classIDs[Classes::Class2]); | |
labels.rowRange(2*N_SAMPLES_PER_CLASS, 3*N_SAMPLES_PER_CLASS).setTo(classIDs[Classes::Class3]); | |
labels.rowRange(3*N_SAMPLES_PER_CLASS, 4*N_SAMPLES_PER_CLASS).setTo(classIDs[Classes::Class4]); | |
return std::make_pair(data, labels); | |
} | |
// Trains one two-class ("this class vs. all the others") C-SVC per distinct
// label found in the training labels.
//
// trainingData.first  : samples, one per row.
// trainingData.second : integer class ID of each sample (see classIDs).
//
// Returns a map from class to its trained SVM.
// Throws std::out_of_range if a label has no entry in `classTypes`, instead
// of silently treating the unknown label as a default-constructed class.
std::map<Classes, cv::Ptr<cv::ml::SVM>> trainAllOneToOthersSVMs(const std::pair<cv::Mat1f, cv::Mat1i>& trainingData)
{
    const cv::Mat1f& data = trainingData.first;
    const cv::Mat1i& labels = trainingData.second;

    //const int KERNEL = cv::ml::SVM::CHI2;
    const int KERNEL = cv::ml::SVM::INTER;

    // Switch between trainAuto() and a fixed, hand-picked parameter set.
    //const bool AUTO_TRAIN_ENABLED = false;
    const bool AUTO_TRAIN_ENABLED = true;

    std::map<Classes, cv::Ptr<cv::ml::SVM>> svms;

    for (const int label : uniqueValues(labels)) {
        const Classes currentClassType = classTypes.at(label);

        cv::Ptr<cv::ml::SVM> currentSVM = cv::ml::SVM::create();
        currentSVM->setType(cv::ml::SVM::C_SVC);
        currentSVM->setKernel(KERNEL);

        // Binary relabelling for this class: the comparison mask is scaled
        // down by 255, so rows of the current class become 0 and rows of
        // every other class become 1 (OpenCV comparison masks are 0 / 255).
        const cv::Mat1i currentLabels = (labels != classIDs.at(currentClassType)) / 255;

        if (AUTO_TRAIN_ENABLED)
        {
            cv::Ptr<cv::ml::TrainData> currentTrainingData =
                cv::ml::TrainData::create(data, cv::ml::ROW_SAMPLE, currentLabels);
            currentSVM->trainAuto(currentTrainingData);
        }
        else
        {
            currentSVM->setC(0.1);
            currentSVM->setGamma(0.001);
            currentSVM->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER, static_cast<int>(1e7), 1e-6));
            currentSVM->train(data, cv::ml::ROW_SAMPLE, currentLabels);
        }

        svms[currentClassType] = currentSVM;
    }

    return svms;
}
int main() | |
{ | |
std::pair<cv::Mat1f, cv::Mat1i> trainingData = createTrainingData(); | |
std::map<Classes, cv::Ptr<cv::ml::SVM>> svms = trainAllOneToOthersSVMs(trainingData); | |
showRegions(svms); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment