Skip to content

Instantly share code, notes, and snippets.

@ubnt-intrepid
Last active August 29, 2015 14:18
Show Gist options
  • Save ubnt-intrepid/e7d74e98d7beb1513f20 to your computer and use it in GitHub Desktop.
Save ubnt-intrepid/e7d74e98d7beb1513f20 to your computer and use it in GitHub Desktop.
Implementation of a sampler for the Chinese Restaurant Process (CRP) in C++11.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <ctime>
#include <fstream>
#include <iostream>
#include <iterator>
#include <random>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>
#include <boost/progress.hpp>
using namespace std;
/// Streams a vector as "[ e0 e1 ... eN ]": an opening "[ ", each element
/// followed by a single space, then a closing "]". An empty vector prints "[ ]".
/// This bracketed, space-separated form is what main() writes into the
/// generated .m file as a row-vector literal.
///
/// @param os   Destination stream.
/// @param vec  Elements to print; each must itself be streamable to `os`.
/// @return `os`, to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, std::vector<T> const& vec)
{
    os << "[ ";
    for (T const& elem : vec)
    {
        os << elem << ' ';
    }
    return os << "]";
}
/// Seats one new customer according to the Chinese Restaurant Process and
/// updates the table occupancy in place.
///
/// Existing table k is chosen with weight nCustomers[k]; a brand-new table is
/// chosen with weight `alpha`.
///
/// @param engine      Uniform random bit generator (e.g. std::mt19937_64).
/// @param alpha       Concentration parameter. Must be finite and >= 0, and
///                    strictly positive while no customer is seated yet.
/// @param nCustomers  Number of customers at each existing table; a new table
///                    is appended (with count 1) or an existing count is
///                    incremented.
/// @return Index of the chosen table. It equals the previous
///         nCustomers.size() exactly when a new table was opened.
/// @throws std::invalid_argument if `alpha` would give std::discrete_distribution
///         an invalid weight set (negative/NaN/infinite weight, or all-zero
///         weights). Passing such weights to it is undefined behaviour.
template <class Engine>
std::size_t sample_from_crp(Engine& engine, double alpha, std::vector<std::size_t>& nCustomers)
{
    // `!(alpha >= 0.0)` also rejects NaN.
    if (!(alpha >= 0.0) || std::isinf(alpha))
        throw std::invalid_argument("sample_from_crp: alpha must be finite and non-negative");

    const bool anySeated = std::any_of(nCustomers.begin(), nCustomers.end(),
                                       [](std::size_t c) { return c != 0; });
    if (alpha == 0.0 && !anySeated)
        throw std::invalid_argument("sample_from_crp: alpha must be positive when no customer is seated");

    // One weight per existing table (its occupancy) plus `alpha` for a new table.
    std::vector<double> weights(nCustomers.begin(), nCustomers.end());
    weights.push_back(alpha);
    std::discrete_distribution<std::size_t> dist(weights.begin(), weights.end());

    const std::size_t k = dist(engine);
    if (k == nCustomers.size())
        nCustomers.push_back(1);  // open a new table
    else
        nCustomers[k] += 1;       // one more customer at the k-th table
    return k;
}
/// Draws table assignments s_1, ..., s_num ~ CRP(alpha) by seating `num`
/// customers one after another with the single-customer overload.
///
/// @param engine Uniform random bit generator.
/// @param alpha  Concentration parameter, forwarded unchanged to the
///               single-customer overload.
/// @param num    Number of customers to seat; 0 yields an empty assignment
///               and zero tables.
/// @return Pair of (table index of each customer in arrival order,
///         number of distinct tables opened).
template <class Engine>
std::pair<std::vector<std::size_t>, std::size_t> sample_from_crp(Engine& engine, double alpha, std::size_t num)
{
    std::vector<std::size_t> assignment(num);
    std::vector<std::size_t> tables;  // customers per table; grows as tables open
    for (std::size_t n = 0; n < num; ++n)
        assignment[n] = sample_from_crp(engine, alpha, tables);
    // Move the (potentially large) assignment into the result rather than copying it.
    return std::make_pair(std::move(assignment), tables.size());
}
/// Draws N customers from CRP(alpha), times the sampling, and writes the
/// result to "load_result.m" as three assignments:
///   num_tables  - number of tables opened,
///   ret         - table index of each customer in arrival order,
///   acc         - number of customers seated at each table.
///
/// @return 0 on success, 1 if "load_result.m" could not be opened or written.
int main(int /*argc*/, char const* /*argv*/[])
{
    constexpr double alpha = 10;
    constexpr size_t N = 100000;
    // Seeded from the wall clock (typically 1 s resolution), so runs started
    // within the same second produce identical samples.
    mt19937_64 engine(time(nullptr));

    // s1, ..., sN ~ CRP(alpha)
    vector<size_t> ret;
    size_t num_tables = 0;
    {
        // boost::progress_timer reports the elapsed time when it is destroyed
        // at the end of this scope, so only the sampling is timed.
        boost::progress_timer t;
        tie(ret, num_tables) = sample_from_crp(engine, alpha, N);
    }

    // Count the customers seated at each table.
    vector<size_t> acc(num_tables, 0);
    for (size_t k : ret)
        acc[k] += 1;

    ofstream fs("load_result.m");
    if (!fs) {
        cerr << "error: cannot open load_result.m for writing" << endl;
        return 1;
    }
    fs << "num_tables = " << num_tables << ";\n";
    fs << "ret = " << ret << ";\n";
    fs << "acc = " << acc << ";\n";
    fs.close();
    if (!fs) {
        cerr << "error: failed while writing load_result.m" << endl;
        return 1;
    }
    return 0;
}
% Plot the empirical table-occupancy distribution produced by the C++ sampler.
% Running load_result defines num_tables, ret (table index of each customer)
% and acc (number of customers at each table).
load_result
% acc holds raw counts; normalize so the y-axis really is a probability.
% bar(y) places table k (0-based in the C++ output) at x = k + 1.
h = bar(acc / sum(acc), 'histc'); grid on
% xlim must be strictly increasing, so guard the single-table case.
set(gca, 'xlim', [1 max(2, length(acc))])
set(h, 'facecolor','g');
xlabel('table id')
ylabel('prob')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment