Created
February 4, 2018 15:38
-
-
Save y3nr1ng/d6e63c2d08611ecd3474e74308382af7 to your computer and use it in GitHub Desktop.
Simple k-means implementation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <vector>
// Input file read by main(), resolved relative to the working directory.
constexpr const char* DATASET_FILENAME = "iris.data";
// Iteration stops once the mean squared centroid displacement drops below this.
// Kept as double so the comparison against the float error matches the old macro.
constexpr double TOLERANCE = 1e-3;
// Iteration index at which the k-means loop stops even without convergence.
constexpr int MAX_ITER = 1000;
/// A 2-D sample (length, width) together with a cluster label.
///
/// A freshly constructed sample has label -1 (no cluster assigned yet)
/// and zero coordinates unless they are supplied.
struct Data {
    int label = -1;
    float length = 0.0f;
    float width = 0.0f;

    Data() = default;

    Data(float length_in, float width_in) : length(length_in), width(width_in) {}
};
float calc_dist2(const Data data_a, const Data data_b) { | |
return (data_a.length-data_b.length)*(data_a.length-data_b.length) + | |
(data_a.width-data_b.width)*(data_a.width-data_b.width); | |
} | |
float calc_tol( | |
const std::vector<Data>& curr_cents, | |
const std::vector<Data>& next_cents | |
) { | |
const std::size_t n = curr_cents.size(); | |
float dist2_sum = 0.0f; | |
for (auto i = 0; i < n; i++) { | |
dist2_sum += calc_dist2(curr_cents[i], next_cents[i]); | |
} | |
return dist2_sum / n; | |
} | |
int find_label(const Data data, const std::vector<Data>& cents) { | |
const std::size_t n = cents.size(); | |
float min_dist2 = 0.0f; | |
int min_index = -1; | |
for (auto i = 0; i < n; i++) { | |
float dist2 = calc_dist2(data, cents[i]); | |
if (min_dist2 > dist2 or min_index < 0) { | |
min_dist2 = dist2; | |
min_index = i; | |
} | |
} | |
return min_index; | |
} | |
void update_centroid( | |
const std::vector<Data>& dataset, | |
std::vector<Data>& centroids | |
) { | |
const std::size_t n_cents = centroids.size(); | |
for (auto i = 0; i < n_cents; i++) { | |
centroids[i].label = 0; | |
centroids[i].length = centroids[i].width = 0.0f; | |
} | |
const std::size_t n_dataset = dataset.size(); | |
for (auto i = 0; i < n_dataset; i++) { | |
const Data& data = dataset[i]; | |
centroids[data.label].label++; | |
centroids[data.label].length += data.length; | |
centroids[data.label].width += data.width; | |
} | |
for (auto i = 0; i < n_cents; i++) { | |
centroids[i].length /= centroids[i].label; | |
centroids[i].width /= centroids[i].label; | |
} | |
} | |
float kmean( | |
std::vector<Data>& dataset, | |
const std::vector<Data>& curr_cents, std::vector<Data>& next_cents | |
) { | |
const std::size_t n = dataset.size(); | |
for (auto i = 0; i < n; i++) { | |
dataset[i].label = find_label(dataset[i], curr_cents); | |
} | |
update_centroid(dataset, next_cents); | |
return calc_tol(curr_cents, next_cents); | |
} | |
void init_centroids(const std::vector<Data>& dataset, std::vector<Data>& cents) { | |
float min_length = -1, max_length = -1; | |
float min_width = -1, max_width = -1; | |
const std::size_t n_dataset = dataset.size(); | |
for (auto i = 0; i < n_dataset; i++) { | |
if (min_length < 0 or min_length > dataset[i].length) { | |
min_length = dataset[i].length; | |
} | |
if (max_length < 0 or max_length < dataset[i].length) { | |
max_length = dataset[i].length; | |
} | |
if (min_width < 0 or min_width > dataset[i].width) { | |
min_width = dataset[i].width; | |
} | |
if (max_width < 0 or max_width < dataset[i].width) { | |
max_width = dataset[i].width; | |
} | |
} | |
std::cerr << std::endl; | |
std::cerr << "min(length) = " << min_length << ", max(length) = " << max_length << std::endl; | |
std::cerr << "min(width) = " << min_width << ", max(width) = " << max_width << std::endl; | |
std::cerr << std::endl; | |
std::cerr << "Generating centroids" << std::endl; | |
auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count(); | |
std::mt19937 rng(seed); | |
std::uniform_real_distribution<float> length_gen(min_length, max_length); | |
std::uniform_real_distribution<float> width_gen(min_width, max_width); | |
const std::size_t n_cents = cents.size(); | |
for (auto i = 0; i < n_cents; i++) { | |
cents[i].label = i; | |
cents[i].length = length_gen(rng); | |
cents[i].width = width_gen(rng); | |
std::cerr << i << ", ("; | |
std::cerr << cents[i].length << ", " << cents[i].width << ")" << std::endl; | |
} | |
} | |
std::vector<Data> kmean(std::vector<Data>& dataset, int n_groups = 3) { | |
std::vector<Data> curr_cents(n_groups), next_cents(n_groups); | |
init_centroids(dataset, curr_cents); | |
float error; | |
for (auto i = 0; ; i++) { | |
error = kmean(dataset, curr_cents, next_cents); | |
std::swap(curr_cents, next_cents); | |
std::cerr << std::endl; | |
std::cerr << "iter " << i << ", error = " << error << std::endl; | |
// early stop conditions | |
if (error < TOLERANCE or i >= MAX_ITER) { | |
break; | |
} | |
// reset the computation if NaN reached | |
if (std::isnan(error)) { | |
std::cerr << ".. reset" << std::endl; | |
init_centroids(dataset, curr_cents); | |
i = 0; | |
} | |
} | |
return curr_cents; | |
} | |
/// Predicate for std::remove_if: true for characters that should be stripped.
///
/// A character is kept when it is printable or blank (space/tab) in the current
/// C locale, and rejected otherwise. The argument goes through `unsigned char`
/// first, because passing a negative char to the <cctype> classifiers is undefined.
struct InvalidChar {
    bool operator()(char c) const {
        const unsigned char uc = static_cast<unsigned char>(c);
        const bool keep = isprint(uc) || isblank(uc);
        return !keep;
    }
};
void read_from_file(std::ifstream& infile, std::vector<Data>& dataset) { | |
std::string species; | |
float length, width; | |
std::cerr << std::endl; | |
std::cerr << "Reading from file" << std::endl; | |
std::string line; | |
std::istringstream iss; | |
while(std::getline(infile, line, '\n')) { | |
// remove non-ASCII characters | |
line.erase(std::remove_if(line.begin(), line.end(), InvalidChar()), line.end()); | |
iss.clear(); | |
iss.str(line); | |
iss >> species >> length >> width; | |
dataset.emplace_back(length, width); | |
} | |
std::cerr << dataset.size() << " entries read" << std::endl; | |
} | |
int main(void) { | |
std::ifstream infile(DATASET_FILENAME); | |
std::vector<Data> dataset, centroids; | |
read_from_file(infile, dataset); | |
centroids = kmean(dataset, 3); | |
// using label field to do the accounting | |
const std::size_t n_dataset = dataset.size(); | |
const std::size_t n_cents = centroids.size(); | |
for (auto i = 0; i < n_cents; i++) { | |
centroids[i].label = 0; | |
} | |
for (auto i = 0; i < n_dataset; i++) { | |
const Data& data = dataset[i]; | |
centroids[data.label].label++; | |
} | |
std::cout << std::endl; | |
for (auto i = 0; i < n_cents; i++) { | |
std::cout << "(" << centroids[i].length << ", " << centroids[i].width << "), "; | |
std::cout << "n = " << centroids[i].label << std::endl; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment