Created
February 22, 2018 14:50
-
-
Save buyoh/5c985d36bf14d9a82f8e698842c866ff to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma GCC optimize ("O3") | |
#pragma GCC target ("avx") | |
#include "bits/stdc++.h" // define macro "/D__MAI" | |
using namespace std; | |
typedef long long int ll;
// Debug helpers: print "L<line> <expression text> > " followed by the value(s).
//   debug(v)          - a single printable value
//   debugv(v)         - every element of a range-iterable container
//   debuga(m,w)       - the first w elements of an indexable m
//   debugaa(m,h,w)    - an h x w grid m[y][x], one row per line
#define debug(v) {printf("L%d %s > ",__LINE__,#v);cout<<(v)<<endl;}
#define debugv(v) {printf("L%d %s > ",__LINE__,#v);for(auto e:(v)){cout<<e<<" ";}cout<<endl;}
#define debuga(m,w) {printf("L%d %s > ",__LINE__,#m);for(int x=0;x<(w);x++){cout<<(m)[x]<<" ";}cout<<endl;}
#define debugaa(m,h,w) {printf("L%d %s >\n",__LINE__,#m);for(int y=0;y<(h);y++){for(int x=0;x<(w);x++){cout<<(m)[y][x]<<" ";}cout<<endl;}}
// Expands to the begin/end iterator pair of a container.
#define ALL(v) (v).begin(),(v).end()
// cnt = 0, 1, ..., l-1 (counter is long long).
#define repeat(cnt,l) for(auto cnt=0ll;(cnt)<(l);++(cnt))
// cnt = l-1, ..., 1, 0. The counter is forced to a signed long long: with an
// unsigned bound (e.g. v.size()) an `auto` counter would be unsigned, so
// `0<=cnt` would never become false and the loop would run past 0.
#define rrepeat(cnt,l) for(auto cnt=static_cast<long long>(l)-1;0<=(cnt);--(cnt))
// cnt = b, b+1, ... until it equals e (exclusive).
#define iterate(cnt,b,e) for(auto cnt=(b);(cnt)!=(e);++(cnt))
// cnt = b, b-1, ... until it equals e (exclusive).
#define diterate(cnt,b,e) for(auto cnt=(b);(cnt)!=(e);--(cnt))
// Common modulus 10^9 + 7.
#define MD 1000000007ll
#define PI 3.1415926535897932384626433832795
// Streams a pair as "(first:second)". Taken by const reference so printing
// does not copy both members.
template<typename T1, typename T2> std::ostream& operator <<(std::ostream &o, const std::pair<T1, T2>& p) { o << "(" << p.first << ":" << p.second << ")"; return o; }
// to = max(to, val); returns a reference to `to`.
template<typename T> T& maxset(T& to, const T& val) { return to = std::max(to, val); }
// to = min(to, val); returns a reference to `to`.
template<typename T> T& minset(T& to, const T& val) { return to = std::min(to, val); }
// Prints `s` and a newline to stdout, then terminates the process with exit status `code`.
void bye(const std::string& s, int code = 0) { std::cout << s << std::endl; std::exit(code); }
// Global 64-bit Mersenne Twister with a fixed seed, so every run is reproducible.
std::mt19937_64 randdev(8901016);
// Uniform random integer in the closed range [l, h]; requires l <= h.
inline long long rand_range(long long l, long long h) {
    return std::uniform_int_distribution<long long>(l, h)(randdev);
}
// Map the *_unlocked stdio calls used by the fast reader/printer onto what the
// platform offers:
//   - neither Windows nor a GNU-compatible compiler: plain (locking) getchar/putchar;
//   - Windows: the CRT _nolock variants;
//   - GNU-compatible compilers: no mapping; the C library's own
//     getchar_unlocked/putchar_unlocked are assumed to exist.
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__GNUC__)
#define getchar_unlocked getchar
#define putchar_unlocked putchar
#elif defined(_WIN32) || defined(_WIN64)
#define getchar_unlocked _getchar_nolock
#define putchar_unlocked _putchar_nolock
#endif
namespace {
// True for printable, non-space ASCII characters ('!' .. '~').
#define isvisiblechar(c) (0x21<=(c)&&(c)<=0x7E)
// Minimal fast reader for integers and whitespace-delimited tokens on stdin.
class MaiScanner {
public:
    // Reads the next decimal integer into `var`.
    // Every character before the first digit is skipped; the number is negative
    // only when the character immediately before its first digit is '-'.
    // If the input ends before any digit is found, `var` is left as 0 (the
    // previous version spun forever on EOF here).
    // Digits are accumulated in the unsigned counterpart of T so that the
    // minimum value (e.g. -9223372036854775808) can be read without signed overflow.
    template<typename T> void input_integer(T& var) {
        using U = typename std::make_unsigned<T>::type;
        var = 0;
        bool negative = false;
        int cc = getchar_unlocked();
        for (; cc < '0' || '9' < cc; cc = getchar_unlocked()) {
            if (cc == EOF) return;
            negative = (cc == '-');
        }
        U magnitude = 0;
        for (; '0' <= cc && cc <= '9'; cc = getchar_unlocked())
            magnitude = magnitude * 10 + static_cast<U>(cc - '0');
        var = static_cast<T>(negative ? U(0) - magnitude : magnitude);
    }
    // Returns the next raw character (or EOF).
    inline int c() { return getchar_unlocked(); }
    inline MaiScanner& operator>>(int& var) { input_integer<int>(var); return *this; }
    inline MaiScanner& operator>>(long long& var) { input_integer<long long>(var); return *this; }
    // Reads the next run of visible characters into `var`, replacing its previous
    // contents (like std::istream). At EOF `var` is left empty instead of looping forever.
    inline MaiScanner& operator>>(std::string& var) {
        var.clear();
        int cc = getchar_unlocked();
        while (cc != EOF && !isvisiblechar(cc)) cc = getchar_unlocked();
        for (; isvisiblechar(cc); cc = getchar_unlocked())
            var.push_back(static_cast<char>(cc));
        return *this;
    }
    // Reads one value into each element of [begin, end).
    template<typename IT> void in(IT begin, IT end) { for (auto it = begin; it != end; ++it) *this >> *it; }
};
// Minimal fast writer for integers, characters and strings on stdout.
class MaiPrinter {
public:
    // Writes `var` in decimal. The magnitude is computed in the unsigned type,
    // which is well-defined even for INT_MIN / LLONG_MIN (where `-var` would overflow).
    template<typename T>
    void output_integer(T var) {
        using U = typename std::make_unsigned<T>::type;
        U magnitude = static_cast<U>(var);
        if (var < 0) {
            putchar_unlocked('-');
            magnitude = U(0) - magnitude;
        }
        char digits[24];
        int count = 0;
        do {
            digits[count++] = static_cast<char>('0' + magnitude % 10);
            magnitude /= 10;
        } while (magnitude != 0);
        while (count > 0)
            putchar_unlocked(digits[--count]);
    }
    inline MaiPrinter& operator<<(char c) { putchar_unlocked(c); return *this; }
    inline MaiPrinter& operator<<(int var) { output_integer<int>(var); return *this; }
    inline MaiPrinter& operator<<(long long var) { output_integer<long long>(var); return *this; }
    // Writes a NUL-terminated string; accepts both mutable and const character pointers.
    inline MaiPrinter& operator<<(const char* str_p) { while (*str_p) putchar_unlocked(*(str_p++)); return *this; }
    inline MaiPrinter& operator<<(const std::string& str) {
        const char* p = str.c_str();
        const char* l = p + str.size();
        while (p < l) putchar_unlocked(*p++);
        return *this;
    }
    // Writes every element of [begin, end), each followed by `sep`.
    template<typename IT> void join(IT begin, IT end, char sep = '\n') { for (auto it = begin; it != end; ++it) *this << *it << sep; }
};
}
// Shared reader/writer instances.
MaiScanner scanner;
MaiPrinter printer;
// A binary decision-tree node.
// A leaf returns a stored output value. A branch holds a predicate
// `comparator_(input)`; eval() descends into child [1] when it is true and into
// child [0] when it is false. Children are owned by the node.
template<typename input_t, typename output_t>
class DecisionTree {
    using comparator_t = std::function<bool(input_t)>;
    bool leaf_;
    comparator_t comparator_;
    output_t out_;
    std::unique_ptr<DecisionTree> children_[2];
public:
    // A fresh node is a leaf whose output is value-initialized (previously
    // `out_` was left uninitialized, so eval() on a fresh node read garbage).
    DecisionTree() : leaf_(true), comparator_(nullptr), out_() {}
    // Child 0 (predicate false) or 1 (predicate true). Only valid on a branch node.
    inline DecisionTree& operator[](int i) { return *children_[i]; }
    inline bool leaf() const { return leaf_; }
    // Turns this node into a leaf returning `out`. Any existing subtrees are
    // destroyed; the previous version used unique_ptr::release(), which only
    // dropped ownership and leaked them.
    inline void generate_leaf(output_t out) {
        leaf_ = true;
        children_[0].reset();
        children_[1].reset();
        comparator_ = nullptr;
        out_ = std::move(out);
    }
    // Turns this node into a branch on `comparator` with two fresh leaf children.
    inline void generate_branch(comparator_t comparator) {
        leaf_ = false;
        children_[0].reset(new DecisionTree());
        children_[1].reset(new DecisionTree());
        comparator_ = std::move(comparator);
    }
    // Follows the predicates from this node down to a leaf and returns its output.
    // Iterative, so evaluation depth is not limited by the call stack.
    output_t eval(input_t val) const {
        const DecisionTree* node = this;
        while (!node->leaf_)
            node = node->children_[node->comparator_(val) ? 1 : 0].get();
        return node->out_;
    }
};
namespace Program { | |
// 乱数生成用 | |
mt19937_64 _randdev(8901016); | |
inline double rand_real(double low, double high){ return uniform_real_distribution<double>(low, high)(_randdev); } | |
inline int rand_int(int low, int high) { return uniform_int_distribution<int>(low, high)(_randdev); } | |
// データ の次元 | |
const int dimension = 2; | |
// データの制約 | |
const double value_upper = 1.0; | |
const double value_lower = -1.0; | |
// 教師データ | |
using data_t = array<double, dimension>; | |
vector<pair<int, data_t>> data; | |
// ランダムフォレスト | |
vector<DecisionTree<data_t, pair<int,int>>> randomforest; | |
// 正解となるクラス分け関数 | |
inline int func(const data_t& x) { | |
return x[0] * x[0] + x[1] * x[1] < 0.49 ? 1 : 0; | |
}; | |
// 正解となるクラス分け関数を元に教師データを作成する | |
void generate_inputdata() { | |
// 教師データ数 | |
const int data_size = 1000; | |
data.reserve(data_size); | |
for (int i = 0; i < data_size; ++i) { | |
data_t x; | |
for (auto& e : x) e = rand_real(value_lower, value_upper); | |
data.emplace_back(func(x), x); | |
} | |
} | |
// 教師データから学習する | |
void learn() { | |
const int tree_num = 100; | |
function<void(decltype(randomforest)::value_type&, vector<int>&)> | |
build_dfs = [&build_dfs](decltype(randomforest)::value_type& dt, vector<int>& selected){ | |
int cnt = accumulate(selected.begin(), selected.end(), 0); | |
if (cnt < 100) { | |
pair<int, int> p; | |
for (int i = 0; i < selected.size(); ++i) { | |
if (selected[i]) { | |
(data[i].first == 0 ? p.first : p.second)++; | |
} | |
} | |
dt.generate_leaf(p); | |
} | |
else { | |
int sel = rand_int(0, dimension - 1); | |
double threshold = rand_real(value_lower, value_upper); | |
dt.generate_branch([sel, threshold](data_t d) {return threshold <= d[sel]; }); | |
{ | |
vector<int> selected_new = selected; | |
for (int i = 0; i < data.size(); ++i) | |
selected_new[i] &= threshold <= data[i].second[sel]; | |
build_dfs(dt[1], selected_new); | |
} | |
{ | |
vector<int> selected_new = selected; | |
for (int i = 0; i < data.size(); ++i) | |
selected_new[i] &= !(threshold <= data[i].second[sel]); | |
build_dfs(dt[0], selected_new); | |
} | |
} | |
}; | |
vector<int> selected_all(data.size(), 1); | |
for (int i = 0; i < tree_num; ++i) { | |
randomforest.emplace_back(); | |
build_dfs(randomforest.back(), selected_all); | |
} | |
} | |
int predict(data_t x) { | |
pair<int, int> rate; | |
for (auto& dt : randomforest) { | |
auto y = dt.eval(x); | |
rate.first += y.first; | |
rate.second += y.second; | |
} | |
return rate.first < rate.second; | |
} | |
void test() { | |
int correct = 0, incorrect = 0; | |
for (int no = 0; no < 10000; ++no) { | |
data_t x; | |
for (auto& e : x) e = rand_real(value_lower, value_upper); | |
int predicted = predict(x); | |
int truth = func(x); | |
(predicted == truth ? correct : incorrect)++; | |
} | |
cout << "correct : " << correct << endl; | |
cout << "incorrect : " << incorrect << endl; | |
} | |
} | |
// Entry point: generate the training set, train the forest, and print its
// accuracy on fresh random points. It then reads one token from stdin before
// exiting (presumably to keep a console window open; the value is unused).
int main() {
    Program::generate_inputdata();
    Program::learn();
    Program::test();
    std::string pause_token;
    std::cin >> pause_token;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment