Last active
August 29, 2015 14:18
-
-
Save ubnt-intrepid/e7d74e98d7beb1513f20 to your computer and use it in GitHub Desktop.
Implementation of a sampler for the Chinese Restaurant Process (CRP) in C++11
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <ctime>
#include <fstream>
#include <iostream>
#include <iterator>
#include <random>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>

#include <boost/progress.hpp>

using namespace std;
/// Streams a vector as "[ e0 e1 ... ]", each element followed by a space.
/// main() writes this text into load_result.m, where it is read as a
/// MATLAB row-vector literal.
///
/// @param os  destination stream
/// @param vec elements to print; each element must support operator<<
/// @return os, so calls can be chained
template <typename T>
std::ostream& operator<<(std::ostream& os, std::vector<T> const& vec)
{
    os << "[ ";
    // Bind each element by const reference. Taking it by value would copy
    // every element just to print it.
    for (T const& v : vec)
        os << v << " ";
    os << "]";
    return os;
}
/// Draws the table for one new customer from a Chinese Restaurant Process
/// and seats the customer there.
///
/// The customer joins occupied table k with weight nCustomers[k], or opens a
/// new table with weight alpha. std::discrete_distribution normalises the
/// weights.
///
/// @param engine     random engine passed to the distribution for the draw
/// @param alpha      concentration parameter. It must be finite and >= 0, and
///                   strictly positive while no table has any customer.
/// @param nCustomers per-table customer counts, updated in place. The chosen
///                   table is incremented, or a new table holding one
///                   customer is appended.
/// @return index of the chosen table; equal to the previous
///         nCustomers.size() when a new table was opened
/// @throws std::invalid_argument if alpha violates the constraints above.
///         std::discrete_distribution requires non-negative weights with a
///         positive sum. nCustomers is left unchanged in that case.
template <class Engine>
std::size_t sample_from_crp(Engine& engine, double alpha, std::vector<std::size_t>& nCustomers)
{
    if (!std::isfinite(alpha) || alpha < 0.0)
        throw std::invalid_argument("sample_from_crp: alpha must be finite and non-negative");
    const bool anyOccupied = std::any_of(nCustomers.begin(), nCustomers.end(),
                                         [](std::size_t c) { return c != 0; });
    if (alpha == 0.0 && !anyOccupied)
        throw std::invalid_argument("sample_from_crp: alpha must be positive when no table is occupied");

    // One weight per existing table, plus alpha for opening a new table.
    // Reserve first so the trailing push_back does not reallocate.
    std::vector<double> weights;
    weights.reserve(nCustomers.size() + 1);
    weights.assign(nCustomers.begin(), nCustomers.end());
    weights.push_back(alpha);

    std::discrete_distribution<std::size_t> dist(weights.begin(), weights.end());
    const std::size_t k = dist(engine);
    if (k == nCustomers.size())
        nCustomers.push_back(1);  // open a new table
    else
        nCustomers[k] += 1;       // one more customer at the k-th table
    return k;
}
/// Seats num customers one after another according to a Chinese Restaurant
/// Process with concentration alpha, starting from an empty restaurant.
///
/// @param engine random engine used for every draw
/// @param alpha  concentration parameter, passed unchanged to the
///               single-customer sample_from_crp overload for each customer
/// @param num    number of customers to seat
/// @return pair of (table index of each customer, in arrival order;
///         number of tables that ended up occupied)
template <class Engine>
std::pair<std::vector<std::size_t>, std::size_t> sample_from_crp(Engine& engine, double alpha, std::size_t num)
{
    std::vector<std::size_t> assignments(num);
    std::vector<std::size_t> tables;  // customer count of each occupied table
    for (std::size_t& table : assignments)
        table = sample_from_crp(engine, alpha, tables);
    // Move the assignments into the pair. make_pair(assignments, ...) would
    // copy all num elements, because NRVO cannot construct a pair member in
    // place.
    return std::make_pair(std::move(assignments), tables.size());
}
int main(int argc, char const* argv[]) | |
{ | |
constexpr double alpha = 10; | |
constexpr size_t N = 100000; | |
mt19937_64 engine(time(nullptr)); | |
// s1, ..., sN ~ CRP(alpha) | |
vector<size_t> ret; | |
size_t num_tables; | |
{ | |
boost::progress_timer t; | |
tie(ret, num_tables) = sample_from_crp(engine, alpha, N); | |
} | |
// accumulate customers of each tables. | |
vector<size_t> acc(num_tables, 0); | |
for (size_t k : ret) | |
acc[k] += 1; | |
{ | |
ofstream fs("load_result.m"); | |
fs << "num_tables = " << num_tables << ';'<<endl; | |
fs << "ret = " << ret << ';' << endl; | |
fs << "acc = " << acc << ';' << endl; | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% Plot how the customers produced by the C++ CRP sampler are spread over
% tables. Running load_result.m defines num_tables, ret and acc.
load_result
% acc holds raw customer counts per table. Divide by the total so the bar
% heights are empirical probabilities, matching the 'prob' y-axis label.
h = bar(acc / sum(acc), 'histc'); grid on
set(gca, 'xlim', [1 length(acc)])
set(h, 'facecolor','g');
xlabel('table id')
ylabel('prob')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment