Skip to content

Instantly share code, notes, and snippets.

@daviddoria
Created January 14, 2017 15:16
Show Gist options
  • Save daviddoria/1725d6dfe1bf6ee2f951eb6ae9b6a973 to your computer and use it in GitHub Desktop.
Save daviddoria/1725d6dfe1bf6ee2f951eb6ae9b6a973 to your computer and use it in GitHub Desktop.
Demonstrates the confidence of one-vs-rest ("one-to-rest") SVMs for every class.
#include <algorithm>
#include <map>
#include <set>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>
// Size, in pixels, of the canvas on which training samples live and on which the
// per-pixel SVM decision regions are drawn.
constexpr int WIDTH = 512;
constexpr int HEIGHT = 512;

// The four class labels used throughout this demo.
enum class Classes { Class1, Class2, Class3, Class4 };
// Per-class display colours, in OpenCV's B, G, R channel order
// (Class1 blue, Class2 green, Class3 red, Class4 yellow).
// showRegions paints a pixel with this colour when the winning class's SVM
// response is non-negative.
std::map<Classes, cv::Vec3b> colorsv{
    {Classes::Class1, cv::Vec3b(255,   0,   0)},
    {Classes::Class2, cv::Vec3b(  0, 255,   0)},
    {Classes::Class3, cv::Vec3b(  0,   0, 255)},
    {Classes::Class4, cv::Vec3b(  0, 255, 255)},
};

// Darker (200 instead of 255) variant of each class colour. showRegions uses it
// when even the winning class's SVM response is negative.
std::map<Classes, cv::Vec3b> colorsv_shaded{
    {Classes::Class1, cv::Vec3b(200,   0,   0)},
    {Classes::Class2, cv::Vec3b(  0, 200,   0)},
    {Classes::Class3, cv::Vec3b(  0,   0, 200)},
    {Classes::Class4, cv::Vec3b(  0, 200, 200)},
};
// Numeric label (1-based) written into the training-label matrix for each class.
std::map<Classes, int> classIDs{
    {Classes::Class1, 1},
    {Classes::Class2, 2},
    {Classes::Class3, 3},
    {Classes::Class4, 4},
};

// Inverse of classIDs: maps a numeric training label back to its class.
std::map<int, Classes> classTypes{
    {1, Classes::Class1},
    {2, Classes::Class2},
    {3, Classes::Class3},
    {4, Classes::Class4},
};
// Returns the (key, value) entry of `x` that has the largest value.
//
// Ties are resolved in favour of the entry that comes first in the map's iteration
// order (std::max_element returns the first maximal element).
//
// @param x   map to search; any comparator/allocator is accepted.
// @return    a copy of the winning entry as std::pair<KeyType, ValueType>.
// @throws std::invalid_argument if `x` is empty (there is no maximum to return).
template<typename KeyType, typename ValueType, typename Compare, typename Alloc>
std::pair<KeyType, ValueType> get_max(const std::map<KeyType, ValueType, Compare, Alloc>& x)
{
  if (x.empty())
  {
    throw std::invalid_argument("get_max: map is empty");
  }
  // Compare the map's own value_type (pair<const K, V>) by reference; using
  // pair<K, V> here would construct a temporary copy for every comparison.
  using entry_type = typename std::map<KeyType, ValueType, Compare, Alloc>::value_type;
  const auto best = std::max_element(x.begin(), x.end(),
                                     [](const entry_type& lhs, const entry_type& rhs) {
                                       return lhs.second < rhs.second;
                                     });
  return *best;
}
// Collects every distinct value stored in a single-channel integer matrix.
// The result is a std::set, so values come back sorted ascending.
std::set<int> uniqueValues(const cv::Mat1i& mat)
{
  std::set<int> distinct;
  for (int row = 0; row < mat.rows; ++row)
  {
    // Walk each row through its own pointer so non-continuous matrices (ROIs) work too.
    const int* rowBegin = mat[row];
    distinct.insert(rowBegin, rowBegin + mat.cols);
  }
  return distinct;
}
// Renders which one-vs-rest SVM is most confident at every pixel of the WIDTH x HEIGHT
// canvas, then shows the image in a "Regions" window until a key is pressed.
//
// Each pixel (x = column, y = row) is fed to every SVM in `svms` with
// StatModel::RAW_OUTPUT, so predict() returns the raw decision value rather than a label.
// The class with the largest value wins. The pixel gets colorsv[winner] when that value
// is >= 0, and the darker colorsv_shaded[winner] when even the best value is negative.
// NOTE(review): treating a non-negative raw value as "belongs to this class" relies on
// trainAllOneToOthersSVMs labelling the current class 0 and on OpenCV's two-class
// decision sign convention — confirm if either changes.
//
// @param svms one trained SVM per class; must not be empty (checked with CV_Assert).
void showRegions(const std::map<Classes, cv::Ptr<cv::ml::SVM>>& svms)
{
  CV_Assert(!svms.empty());

  cv::Mat3b regions(HEIGHT, WIDTH);

  // Reused for every pixel instead of allocating a fresh 1x2 matrix and map each time.
  // Every key of `responses` is overwritten on each pixel because `svms` never changes.
  cv::Mat1f sample(1, 2);
  std::map<Classes, float> responses;

  for (int r = 0; r < HEIGHT; ++r)
  {
    for (int c = 0; c < WIDTH; ++c)
    {
      sample(0, 0) = static_cast<float>(c);
      sample(0, 1) = static_cast<float>(r);

      for (const auto& entry : svms)
      {
        responses[entry.first] =
            entry.second->predict(sample, cv::noArray(), cv::ml::StatModel::RAW_OUTPUT);
      }

      const std::pair<Classes, float> best = get_max(responses);
      const auto& palette = (best.second >= 0) ? colorsv : colorsv_shaded;
      regions(r, c) = palette.at(best.first);
    }
  }

  cv::imshow("Regions", regions);
  cv::waitKey();
}
std::pair<cv::Mat1f, cv::Mat1i> createTrainingData()
{
const int N_SAMPLES_PER_CLASS = 10;
const float NON_LINEAR_SAMPLES_RATIO = 0.1;
int N_NON_LINEAR_SAMPLES = N_SAMPLES_PER_CLASS * NON_LINEAR_SAMPLES_RATIO;
int N_LINEAR_SAMPLES = N_SAMPLES_PER_CLASS - N_NON_LINEAR_SAMPLES;
cv::Mat1f data(4 * N_SAMPLES_PER_CLASS, 2);
cv::Mat1i labels(4 * N_SAMPLES_PER_CLASS, 1);
cv::RNG rng(0);
////////////////////////
// Set training data
////////////////////////
// Class 1
cv::Mat1f class1 = data.rowRange(0, 0.5 * N_LINEAR_SAMPLES);
cv::Mat1f x1 = class1.colRange(0, 1);
cv::Mat1f y1 = class1.colRange(1, 2);
rng.fill(x1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH));
rng.fill(y1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT / 8));
class1 = data.rowRange(0.5 * N_LINEAR_SAMPLES, 1 * N_LINEAR_SAMPLES);
x1 = class1.colRange(0, 1);
y1 = class1.colRange(1, 2);
rng.fill(x1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH));
rng.fill(y1, cv::RNG::UNIFORM, cv::Scalar(7*HEIGHT / 8), cv::Scalar(HEIGHT));
class1 = data.rowRange(N_LINEAR_SAMPLES, 1 * N_SAMPLES_PER_CLASS);
x1 = class1.colRange(0, 1);
y1 = class1.colRange(1, 2);
rng.fill(x1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH));
rng.fill(y1, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT));
// Class 2
cv::Mat1f class2 = data.rowRange(N_SAMPLES_PER_CLASS, N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES);
cv::Mat1f x2 = class2.colRange(0, 1);
cv::Mat1f y2 = class2.colRange(1, 2);
rng.fill(x2, cv::RNG::NORMAL, cv::Scalar(3 * WIDTH / 4), cv::Scalar(WIDTH/16));
rng.fill(y2, cv::RNG::NORMAL, cv::Scalar(HEIGHT / 2), cv::Scalar(HEIGHT/4));
class2 = data.rowRange(N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES, 2 * N_SAMPLES_PER_CLASS);
x2 = class2.colRange(0, 1);
y2 = class2.colRange(1, 2);
rng.fill(x2, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH));
rng.fill(y2, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT));
// Class 3
cv::Mat1f class3 = data.rowRange(2 * N_SAMPLES_PER_CLASS, 2 * N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES);
cv::Mat1f x3 = class3.colRange(0, 1);
cv::Mat1f y3 = class3.colRange(1, 2);
rng.fill(x3, cv::RNG::NORMAL, cv::Scalar(WIDTH / 4), cv::Scalar(WIDTH/8));
rng.fill(y3, cv::RNG::NORMAL, cv::Scalar(HEIGHT / 2), cv::Scalar(HEIGHT/8));
class3 = data.rowRange(2*N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES, 3 * N_SAMPLES_PER_CLASS);
x3 = class3.colRange(0, 1);
y3 = class3.colRange(1, 2);
rng.fill(x3, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH));
rng.fill(y3, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT));
// Class 4
cv::Mat1f class4 = data.rowRange(3 * N_SAMPLES_PER_CLASS, 3 * N_SAMPLES_PER_CLASS + 0.5 * N_LINEAR_SAMPLES);
cv::Mat1f x4 = class4.colRange(0, 1);
cv::Mat1f y4 = class4.colRange(1, 2);
rng.fill(x4, cv::RNG::NORMAL, cv::Scalar(WIDTH / 2), cv::Scalar(WIDTH / 16));
rng.fill(y4, cv::RNG::NORMAL, cv::Scalar(HEIGHT / 4), cv::Scalar(HEIGHT / 16));
class4 = data.rowRange(3 * N_SAMPLES_PER_CLASS + 0.5 * N_LINEAR_SAMPLES, 3 * N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES);
x4 = class4.colRange(0, 1);
y4 = class4.colRange(1, 2);
rng.fill(x4, cv::RNG::NORMAL, cv::Scalar(WIDTH / 2), cv::Scalar(WIDTH / 16));
rng.fill(y4, cv::RNG::NORMAL, cv::Scalar(3 * HEIGHT / 4), cv::Scalar(HEIGHT / 16));
class4 = data.rowRange(3 * N_SAMPLES_PER_CLASS + N_LINEAR_SAMPLES, 4 * N_SAMPLES_PER_CLASS);
x4 = class4.colRange(0, 1);
y4 = class4.colRange(1, 2);
rng.fill(x4, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(WIDTH));
rng.fill(y4, cv::RNG::UNIFORM, cv::Scalar(1), cv::Scalar(HEIGHT));
// Labels
labels.rowRange(0*N_SAMPLES_PER_CLASS, 1*N_SAMPLES_PER_CLASS).setTo(classIDs[Classes::Class1]);
labels.rowRange(1*N_SAMPLES_PER_CLASS, 2*N_SAMPLES_PER_CLASS).setTo(classIDs[Classes::Class2]);
labels.rowRange(2*N_SAMPLES_PER_CLASS, 3*N_SAMPLES_PER_CLASS).setTo(classIDs[Classes::Class3]);
labels.rowRange(3*N_SAMPLES_PER_CLASS, 4*N_SAMPLES_PER_CLASS).setTo(classIDs[Classes::Class4]);
return std::make_pair(data, labels);
}
// Trains one binary "one-vs-rest" C_SVC SVM per distinct label in trainingData.second.
//
// For label L, every row labelled L gets target 0 and every other row gets target 1.
// The SVM is then trained either with trainAuto (parameter search, the enabled path)
// or with fixed parameters.
//
// @param trainingData (samples, labels). samples holds one ROW_SAMPLE per row. Every
//                     label must be a key of classTypes.
// @return one trained SVM per class that occurs in the labels.
// @throws std::out_of_range if a label has no entry in classTypes. A silent lookup would
//         instead map it to a default class and overwrite that class's SVM.
std::map<Classes, cv::Ptr<cv::ml::SVM>> trainAllOneToOthersSVMs(const std::pair<cv::Mat1f, cv::Mat1i>& trainingData)
{
  const cv::Mat1f& data = trainingData.first;
  const cv::Mat1i& labels = trainingData.second;
  //const int KERNEL = SVM::CHI2;
  const int KERNEL = cv::ml::SVM::INTER;
  //const bool AUTO_TRAIN_ENABLED = false;
  const bool AUTO_TRAIN_ENABLED = true;

  std::map<Classes, cv::Ptr<cv::ml::SVM>> svms;
  for (const int classID : uniqueValues(labels))
  {
    const Classes currentClassType = classTypes.at(classID);

    cv::Ptr<cv::ml::SVM> currentSVM = cv::ml::SVM::create();
    svms[currentClassType] = currentSVM;
    currentSVM->setType(cv::ml::SVM::C_SVC);
    currentSVM->setKernel(KERNEL);

    // (labels != classID) yields a 0/255 mask; dividing by 255 gives target 0 for the
    // current class and 1 for all other classes.
    cv::Mat1i currentLabels = (labels != classID) / 255;

    if (AUTO_TRAIN_ENABLED)
    {
      cv::Ptr<cv::ml::TrainData> currentTrainingData =
          cv::ml::TrainData::create(data, cv::ml::ROW_SAMPLE, currentLabels);
      currentSVM->trainAuto(currentTrainingData);
    }
    else
    {
      currentSVM->setC(0.1);
      // NOTE: per the OpenCV docs gamma only affects POLY/RBF/SIGMOID/CHI2 kernels,
      // so it has no effect with INTER.
      currentSVM->setGamma(0.001);
      // Only MAX_ITER is set in the criterion type, so the 1e-6 epsilon is not used.
      currentSVM->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER, (int)1e7, 1e-6));
      currentSVM->train(data, cv::ml::ROW_SAMPLE, currentLabels);
    }
  }
  return svms;
}
// Entry point: build the 4-class toy data set, fit one one-vs-rest SVM per class,
// and display the resulting per-pixel decision regions.
int main()
{
  showRegions(trainAllOneToOthersSVMs(createTrainingData()));
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment