Skip to content

Instantly share code, notes, and snippets.

@buyoh
Created February 22, 2018 14:50
Show Gist options
  • Save buyoh/5c985d36bf14d9a82f8e698842c866ff to your computer and use it in GitHub Desktop.
Save buyoh/5c985d36bf14d9a82f8e698842c866ff to your computer and use it in GitHub Desktop.
#pragma GCC optimize ("O3")
#pragma GCC target ("avx")
#include "bits/stdc++.h" // define macro "/D__MAI"
using namespace std;
// Shorthand for `long long int` (a signed integer of at least 64 bits).
using ll = long long int;
// Debug helpers: each prints "L<line> <expr> > " and then the value(s) of the expression to cout.
// debug(v): prints a single value with operator<<, then endl.
#define debug(v) {printf("L%d %s > ",__LINE__,#v);cout<<(v)<<endl;}
// debugv(v): prints every element of a range-for-iterable v, separated by spaces.
#define debugv(v) {printf("L%d %s > ",__LINE__,#v);for(auto e:(v)){cout<<e<<" ";}cout<<endl;}
// debuga(m,w): prints m[0] .. m[w-1] on one line.
// NOTE: the body declares its own loop variable `x`; an `m` or `w` argument naming a variable `x`
// will see the loop counter instead.
#define debuga(m,w) {printf("L%d %s > ",__LINE__,#m);for(int x=0;x<(w);x++){cout<<(m)[x]<<" ";}cout<<endl;}
// debugaa(m,h,w): prints the h-by-w grid m[y][x], one row per line (same caveat for `y` and `x`).
#define debugaa(m,h,w) {printf("L%d %s >\n",__LINE__,#m);for(int y=0;y<(h);y++){for(int x=0;x<(w);x++){cout<<(m)[y][x]<<" ";}cout<<endl;}}
// ALL(v): expands to "v.begin(), v.end()" for passing a whole container to an algorithm.
#define ALL(v) (v).begin(),(v).end()
// repeat(cnt,l): cnt = 0, 1, ..., l-1; cnt is a long long (initialised from 0ll).
#define repeat(cnt,l) for(auto cnt=0ll;(cnt)<(l);++(cnt))
// rrepeat(cnt,l): cnt = l-1, l-2, ..., 0; cnt has the type of (l)-1.
// NOTE: if (l)-1 is unsigned (e.g. l is a size()), 0<=(cnt) is always true and the loop never ends.
#define rrepeat(cnt,l) for(auto cnt=(l)-1;0<=(cnt);--(cnt))
// iterate(cnt,b,e): cnt = b, b+1, ... and stops when cnt equals e (e itself is not visited).
#define iterate(cnt,b,e) for(auto cnt=(b);(cnt)!=(e);++(cnt))
// diterate(cnt,b,e): cnt = b, b-1, ... and stops when cnt equals e (e itself is not visited).
#define diterate(cnt,b,e) for(auto cnt=(b);(cnt)!=(e);--(cnt))
// MD: the prime modulus 10^9 + 7 as a long long literal.
#define MD 1000000007ll
// PI: pi written to 31 decimals (a double literal, so only about 16 significant digits survive).
#define PI 3.1415926535897932384626433832795
// Streams a pair as "(first:second)". Members are printed with operator<<, so nested pairs
// recurse through this same template. The pair is taken by const reference: the previous
// by-value parameter copied both members on every call.
template<typename T1, typename T2>
ostream& operator<<(ostream& o, const pair<T1, T2>& p) {
    return o << "(" << p.first << ":" << p.second << ")";
}
// Raises `to` to `val` when `to < val`; always assigns, exactly like `to = max(to, val)`.
// Returns a reference to `to`.
template<typename T> T& maxset(T& to, const T& val) {
    to = (to < val) ? val : to;
    return to;
}
// Lowers `to` to `val` when `val < to`; always assigns, exactly like `to = min(to, val)`.
// Returns a reference to `to`.
template<typename T> T& minset(T& to, const T& val) {
    to = (val < to) ? val : to;
    return to;
}
// Writes `s` and a newline to cout (endl also flushes), then terminates the process with exit
// status `code`. The message is taken by const reference to avoid a copy. The function is
// [[noreturn]] so callers need no dummy return after it and the compiler can warn accordingly.
[[noreturn]] void bye(const string& s, int code = 0) {
    cout << s << endl;
    exit(code);
}
// Global 64-bit Mersenne Twister seeded with the fixed value 8901016, so the sequence repeats
// identically on every run.
mt19937_64 randdev{8901016};
// Draws an integer uniformly from the closed range [l, h] using randdev.
inline ll rand_range(ll l, ll h) {
    uniform_int_distribution<ll> pick(l, h);
    return pick(randdev);
}
// Character I/O shims: make the names getchar_unlocked / putchar_unlocked usable on each toolchain.
#if defined(_WIN32) || defined(_WIN64)
// Windows: map to the Microsoft CRT's non-locking _getchar_nolock / _putchar_nolock.
// This test comes first, so a GCC-based Windows toolchain also takes this branch.
#define getchar_unlocked _getchar_nolock
#define putchar_unlocked _putchar_nolock
#elif defined(__GNUC__)
// GCC/Clang elsewhere: nothing is defined; the C library's own getchar_unlocked /
// putchar_unlocked are used (assumed available, e.g. glibc -- confirm for other libcs).
#else
// Any other compiler: fall back to the ordinary (locking) getchar / putchar.
#define getchar_unlocked getchar
#define putchar_unlocked putchar
#endif
namespace {
// True for printable, non-space ASCII ('!' .. '~'). MaiScanner uses it as the token character
// class; EOF (-1) is not visible, so it also ends a token.
#define isvisiblechar(c) (0x21<=(c)&&(c)<=0x7E)
// Minimal fast reader over stdin built on getchar_unlocked.
class MaiScanner {
public:
    // Reads the next decimal integer into `var`.
    // Characters before the first digit are skipped, and a '-' among them makes the result
    // negative. The character right after the last digit is consumed. At end of input `var`
    // becomes 0 (previously the skip loop spun forever on EOF). Values outside T's range
    // overflow (undefined behaviour); there is no range check.
    template<typename T> void input_integer(T& var) {
        var = 0; T sign = 1;
        int cc = getchar_unlocked();
        for (; cc != EOF && (cc < '0' || '9' < cc); cc = getchar_unlocked())
            if (cc == '-') sign = -1;
        for (; '0' <= cc && cc <= '9'; cc = getchar_unlocked())
            var = var * 10 + (cc - '0');
        var = var * sign;
    }
    // Reads one raw character (may return EOF).
    inline int c() { return getchar_unlocked(); }
    inline MaiScanner& operator>>(int& var) { input_integer<int>(var); return *this; }
    inline MaiScanner& operator>>(long long& var) { input_integer<long long>(var); return *this; }
    // Reads the next token of visible ASCII characters into `var`; any other byte separates
    // tokens. The previous contents of `var` are replaced, as with std::istream. At end of
    // input `var` is left empty instead of looping forever on EOF.
    inline MaiScanner& operator>>(string& var) {
        var.clear();
        int cc = getchar_unlocked();
        while (cc != EOF && !isvisiblechar(cc)) cc = getchar_unlocked();
        for (; isvisiblechar(cc); cc = getchar_unlocked())
            var.push_back(static_cast<char>(cc));
        return *this;
    }
    // Reads consecutive values into [begin, end) with the matching operator>>.
    template<typename IT> void in(IT begin, IT end) { for (auto it = begin; it != end; ++it) *this >> *it; }
};
// Minimal fast writer to stdout built on putchar_unlocked.
class MaiPrinter {
public:
    // Writes `var` in decimal. The magnitude is computed in the unsigned counterpart of T, so
    // the most negative value (INT_MIN, LLONG_MIN) prints correctly; negating it in T overflowed.
    template<typename T>
    void output_integer(T var) {
        using U = typename make_unsigned<T>::type;
        U mag = static_cast<U>(var);
        if (var < 0) {
            putchar_unlocked('-');
            mag = U(0) - mag;  // modular negation: exact magnitude even for the minimum value
        }
        char stack[32]; int stack_p = 0;  // ample for the 20 digits of a 64-bit magnitude
        do {
            stack[stack_p++] = static_cast<char>('0' + mag % 10);
            mag /= 10;
        } while (mag);
        while (stack_p) putchar_unlocked(stack[--stack_p]);
    }
    inline MaiPrinter& operator<<(char c) { putchar_unlocked(c); return *this; }
    inline MaiPrinter& operator<<(int var) { output_integer<int>(var); return *this; }
    inline MaiPrinter& operator<<(long long var) { output_integer<long long>(var); return *this; }
    // Writes a NUL-terminated C string. Taking const char* lets string literals bind here directly
    // instead of being converted to a temporary std::string; char* arguments still match.
    inline MaiPrinter& operator<<(const char* str_p) { while (*str_p) putchar_unlocked(*(str_p++)); return *this; }
    // Writes every character of `str`, including embedded NULs.
    inline MaiPrinter& operator<<(const string& str) {
        for (char ch : str) putchar_unlocked(ch);
        return *this;
    }
    // Writes each element of [begin, end) followed by `sep` (also after the last element).
    template<typename IT> void join(IT begin, IT end, char sep = '\n') { for (auto it = begin; it != end; ++it) *this << *it << sep; }
};
}
// Shared global instances of the fast stdin reader and the fast stdout writer.
MaiScanner scanner{};
MaiPrinter printer{};
// Binary decision tree. A branch node sends an input to child 1 when its comparator returns
// true and to child 0 otherwise; a leaf node yields its stored output. Children are owned via
// unique_ptr, so trees are movable but not copyable.
template<typename input_t, typename output_t>
class DecisionTree {
    using comparator_t = function<bool(input_t)>;
    bool leaf_;
    comparator_t comparator_;
    output_t out_{};  // value-initialised so a never-assigned leaf yields output_t{}, not garbage
    unique_ptr<DecisionTree> children_[2];
public:
    DecisionTree() :leaf_(true), comparator_(nullptr){}
    // Child i of a branch node (0 = comparator false, 1 = comparator true). Must not be called on a leaf.
    inline DecisionTree& operator[](int i) { return *children_[i]; }
    inline const DecisionTree& operator[](int i) const { return *children_[i]; }
    inline bool leaf() const { return leaf_; }
    // Makes this node a leaf that yields `out`, destroying any existing subtrees.
    inline void generate_leaf(output_t out) {
        leaf_ = true;
        // reset(), not release(): release() only drops ownership and leaked both subtrees.
        children_[0].reset();
        children_[1].reset();
        comparator_ = nullptr;  // leaves never consult it; drop any captured state
        out_ = std::move(out);
    }
    // Makes this node a branch using `comparator`, with two fresh leaf children.
    inline void generate_branch(comparator_t comparator) {
        leaf_ = false;
        children_[0].reset(new DecisionTree());
        children_[1].reset(new DecisionTree());
        comparator_ = std::move(comparator);
    }
    // Follows the comparators from this node down to a leaf and returns that leaf's output.
    output_t eval(const input_t& val) const {
        const DecisionTree* node = this;
        while (!node->leaf_)
            node = node->children_[node->comparator_(val) ? 1 : 0].get();
        return node->out_;
    }
};
namespace Program {
// 乱数生成用
mt19937_64 _randdev(8901016);
inline double rand_real(double low, double high){ return uniform_real_distribution<double>(low, high)(_randdev); }
inline int rand_int(int low, int high) { return uniform_int_distribution<int>(low, high)(_randdev); }
// データ の次元
const int dimension = 2;
// データの制約
const double value_upper = 1.0;
const double value_lower = -1.0;
// 教師データ
using data_t = array<double, dimension>;
vector<pair<int, data_t>> data;
// ランダムフォレスト
vector<DecisionTree<data_t, pair<int,int>>> randomforest;
// 正解となるクラス分け関数
inline int func(const data_t& x) {
return x[0] * x[0] + x[1] * x[1] < 0.49 ? 1 : 0;
};
// 正解となるクラス分け関数を元に教師データを作成する
void generate_inputdata() {
// 教師データ数
const int data_size = 1000;
data.reserve(data_size);
for (int i = 0; i < data_size; ++i) {
data_t x;
for (auto& e : x) e = rand_real(value_lower, value_upper);
data.emplace_back(func(x), x);
}
}
// 教師データから学習する
void learn() {
const int tree_num = 100;
function<void(decltype(randomforest)::value_type&, vector<int>&)>
build_dfs = [&build_dfs](decltype(randomforest)::value_type& dt, vector<int>& selected){
int cnt = accumulate(selected.begin(), selected.end(), 0);
if (cnt < 100) {
pair<int, int> p;
for (int i = 0; i < selected.size(); ++i) {
if (selected[i]) {
(data[i].first == 0 ? p.first : p.second)++;
}
}
dt.generate_leaf(p);
}
else {
int sel = rand_int(0, dimension - 1);
double threshold = rand_real(value_lower, value_upper);
dt.generate_branch([sel, threshold](data_t d) {return threshold <= d[sel]; });
{
vector<int> selected_new = selected;
for (int i = 0; i < data.size(); ++i)
selected_new[i] &= threshold <= data[i].second[sel];
build_dfs(dt[1], selected_new);
}
{
vector<int> selected_new = selected;
for (int i = 0; i < data.size(); ++i)
selected_new[i] &= !(threshold <= data[i].second[sel]);
build_dfs(dt[0], selected_new);
}
}
};
vector<int> selected_all(data.size(), 1);
for (int i = 0; i < tree_num; ++i) {
randomforest.emplace_back();
build_dfs(randomforest.back(), selected_all);
}
}
int predict(data_t x) {
pair<int, int> rate;
for (auto& dt : randomforest) {
auto y = dt.eval(x);
rate.first += y.first;
rate.second += y.second;
}
return rate.first < rate.second;
}
void test() {
int correct = 0, incorrect = 0;
for (int no = 0; no < 10000; ++no) {
data_t x;
for (auto& e : x) e = rand_real(value_lower, value_upper);
int predicted = predict(x);
int truth = func(x);
(predicted == truth ? correct : incorrect)++;
}
cout << "correct : " << correct << endl;
cout << "incorrect : " << incorrect << endl;
}
}
int main() {
    // Build the training set, grow the forest, then report accuracy on fresh random samples.
    Program::generate_inputdata();
    Program::learn();
    Program::test();
    // Block until one token arrives on stdin before exiting
    // (presumably to keep a console window open).
    string pause_token;
    cin >> pause_token;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment