Last active
October 7, 2016 06:53
-
-
Save remyroez/109bda2c6a5f24ff5f604d3332fbd0e6 to your computer and use it in GitHub Desktop.
マッチ箱の脳 - マッチ箱で作るNN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <iomanip> | |
#include <memory> | |
#include <array> | |
#include <cmath> | |
// ニューロン(マッチ箱+お菓子) | |
class Neuron | |
{ | |
public: | |
// ctor | |
Neuron(int value = 10, int data = 0) : _value(value), _data(data) {} | |
// 値段 | |
int data() const { return _data; } | |
void data(int n) { _data = n; } | |
// マッチ棒の数 | |
int value() const { return _value; } | |
void value(int n) { _value = n; } | |
// インクリメント・デクリメント | |
void inc(int n = 1) { _value += n; } | |
void dec(int n = 1) { _value -= n; } | |
// 渡されたマッチ棒の数 | |
int input() const { return _input; } | |
void input(int n) { _input = n; } | |
// 渡されたマッチ棒の数>箱の中のマッチ棒の数なら、興奮する | |
int output() const { return (input() > value()) ? 1 : 0; } | |
private: | |
int _value; | |
int _input = 0; | |
int _data = 0; | |
}; | |
int main() | |
{ | |
// マッチ箱+お菓子(の値段) | |
std::array<std::shared_ptr<Neuron>, 3> list = { | |
std::make_shared<Neuron>(1, 310), | |
std::make_shared<Neuron>(3, 220), | |
std::make_shared<Neuron>(8, 70) | |
}; | |
// マッチ箱A(答え) | |
auto answer = std::make_shared<Neuron>(6); | |
// 値段計算機(指定した値段以下に抑える) | |
auto tester = std::make_shared<Neuron>(500); | |
// 何回か繰り返す | |
for (int i = 0; i < 10; ++i) | |
{ | |
// 世代 | |
std::cout << "generation: " << i << std::endl; | |
// 現在のマッチ箱の情報を表示 | |
for (size_t j = 0; j < list.size(); ++j) | |
{ | |
auto n = list[j]; | |
std::cout << "unit-" << j << ": " << n->value() << std::endl; | |
} | |
std::cout << "answer: " << answer->value() << std::endl; | |
// 間違えた回数 | |
int ng = 0; | |
// 問題を総当りで評価する | |
for (int q = 0; q < std::pow(2, list.size()); ++q) | |
{ | |
// 渡すマッチ棒の数と値段の合計を出す | |
int signal = 0; | |
int value = 0; | |
for (size_t j = 0; j < list.size(); ++j) | |
{ | |
if ((q & (1 << j)) == 0) continue; | |
auto n = list[j]; | |
signal += n->value(); | |
value += n->data(); | |
} | |
// マッチ棒を渡す | |
answer->input(signal); | |
// 値段が範囲内かどうか計算する | |
tester->input(value); | |
int voutput = tester->output(); | |
// 結果を表示 | |
std::cout << q << " - (" << std::setw(3) << value << " <= " << std::setw(3) << tester->value() << ")" | |
<< ", correct: " << voutput | |
<< ", answer: " << answer->output() | |
<< ", " << ((voutput == answer->output()) ? "ok" : "ng") << std::endl; | |
// ペナルティ | |
if (voutput == answer->output()) { | |
// ok | |
} else if ((voutput == 0) && (answer->output() != 0)) { | |
// ng 1: 間違いタイプ1「正しい買い方だったのに、NGと判断してしまった」 | |
for (size_t j = 0; j < list.size(); ++j) | |
{ | |
if ((q & (1 << j)) == 0) continue; | |
list[j]->dec(); | |
} | |
answer->inc(); | |
ng++; | |
} else if ((voutput != 0) && (answer->output() == 0)) { | |
// ng 2: 間違いタイプ2「間違った買い方だったのに、OKと判断してしまった」 | |
for (size_t j = 0; j < list.size(); ++j) | |
{ | |
if ((q & (1 << j)) == 0) continue; | |
list[j]->inc(); | |
} | |
answer->dec(); | |
ng++; | |
} | |
} | |
if (ng == 0) break; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment