Created
September 5, 2016 15:06
-
-
Save 43x2/9db1da7789fa93e00b1fdea2d3e74669 to your computer and use it in GitHub Desktop.
半加算器ニューラルネットワークの学習 (「誤差逆伝播法をはじめからていねいに」 ソースコード)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// main.cpp is placed in PUBLIC DOMAIN. | |
// | |
#include <cmath> | |
#include <vector> | |
#include <random> | |
#include <iostream> | |
#include "SimpleMatrix.h" | |
// 入力データと教師データ | |
struct Data | |
{ | |
struct | |
{ | |
double A; // 入力 a | |
double B; // 入力 b | |
} input; | |
struct | |
{ | |
double S; // 出力 s (和) | |
double C; // 出力 c (キャリー) | |
} supervisor; | |
}; | |
int main( int /*argc*/, char * /*argv*/[] ) | |
{ | |
// 入力データと教師データ | |
const std::vector< Data > datas { | |
// input supervisor | |
{ { 0.0, 0.0 }, { 0.0, 0.0 } }, | |
{ { 0.0, 1.0 }, { 1.0, 0.0 } }, | |
{ { 1.0, 0.0 }, { 1.0, 0.0 } }, | |
{ { 1.0, 1.0 }, { 0.0, 1.0 } } | |
}; | |
// 乱数生成器・一様分布器 | |
std::random_device device; | |
std::mt19937_64 generator( device() ); | |
std::uniform_real_distribution< double > distributor( -8.0, 8.0 ); | |
// 隠れ層への重み | |
SimpleMatrix< double > weightA( 3, 2 ); | |
for ( auto i = 0; i < 3; ++i ) | |
{ | |
for ( auto j = 0; j < 2; ++j ) | |
{ | |
weightA( i, j ) = distributor( generator ); | |
} | |
} | |
// 隠れ層バイアス | |
SimpleMatrix< double > biasA( 3, 1 ); | |
for ( auto i = 0; i < 3; ++i ) | |
{ | |
biasA( i, 0 ) = distributor( generator ); | |
} | |
// 出力層への重み | |
SimpleMatrix< double > weightB( 2, 3 ); | |
for ( auto i = 0; i < 2; ++i ) | |
{ | |
for ( auto j = 0; j < 3; ++j ) | |
{ | |
weightB( i, j ) = distributor( generator ); | |
} | |
} | |
// 出力層バイアス | |
SimpleMatrix< double > biasB( 2, 1 ); | |
for ( auto i = 0; i < 2; ++i ) | |
{ | |
biasB( i, 0 ) = distributor( generator ); | |
} | |
// 学習回数 | |
constexpr std::size_t training_times = 100000; | |
// 学習率 | |
constexpr double learning_rate = 0.3; | |
// | |
// 学習フェーズ | |
// | |
for ( auto i = 0; i < training_times; ++i ) | |
{ | |
// 誤差 | |
auto error = 0.0; | |
// 誤差に対する隠れ層への重みの偏微分 (∂E/∂WA) | |
SimpleMatrix< double > dE_dWA( 3, 2, 0.0 ); | |
// 誤差に対する隠れ層バイアスの偏微分 (∂E/∂bA) | |
SimpleMatrix< double > dE_dbA( 3, 1, 0.0 ); | |
// 誤差に対する出力層への重みの偏微分 (∂E/∂WB) | |
SimpleMatrix< double > dE_dWB( 2, 3, 0.0 ); | |
// 誤差に対する出力層バイアスの偏微分 (∂E/∂bB) | |
SimpleMatrix< double > dE_dbB( 2, 1, 0.0 ); | |
for ( auto & data : datas ) | |
{ | |
// 入力層 | |
SimpleMatrix< double > input( 2, 1, { data.input.A, data.input.B } ); | |
// 隠れ層を計算する | |
auto hidden = weightA * input + biasA; | |
hidden( 0, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 0, 0 ) ) ); // | |
hidden( 1, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 1, 0 ) ) ); // | |
hidden( 2, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 2, 0 ) ) ); // シグモイド関数による非線形変換 | |
// 出力層を計算する | |
auto output = weightB * hidden + biasB; | |
output( 0, 0 ) = 1.0 / ( 1.0 + std::exp( -output( 0, 0 ) ) ); // | |
output( 1, 0 ) = 1.0 / ( 1.0 + std::exp( -output( 1, 0 ) ) ); // シグモイド関数による非線形変換 | |
// 誤差 | |
error += 0.5 * ( std::pow( output( 0, 0 ) - data.supervisor.S, 2.0 ) + std::pow( output( 1, 0 ) - data.supervisor.C, 2.0 ) ); | |
// ∂E/∂bB | |
SimpleMatrix< double > supervisor( 2, 1, { data.supervisor.S, data.supervisor.C } ); | |
auto temp0 = output - supervisor; | |
auto temp1 = temp0.element_product( output ).element_product( SimpleMatrix< double >( 2, 1, 1.0 ) - output ); | |
dE_dbB += temp1; | |
// ∂E/∂WB | |
auto temp2 = temp1 * hidden.transpose(); | |
dE_dWB += temp2; | |
// ∂E/∂bA | |
auto temp3 = ( temp1.transpose() * weightB ).transpose(); | |
auto temp4 = temp3.element_product( hidden ).element_product( SimpleMatrix< double >( 3, 1, 1.0 ) - hidden ); | |
dE_dbA += temp4; | |
// ∂E/∂WA | |
auto temp5 = temp4 * input.transpose(); | |
dE_dWA += temp5; | |
} | |
// 100 回ごとに報告 | |
if ( i % 100 == 0 ) | |
{ | |
std::cout << i << " -- " << error << std::endl; | |
} | |
// 重みとバイアスに反映 | |
weightA -= ( dE_dWA *= learning_rate ); | |
biasA -= ( dE_dbA *= learning_rate ); | |
weightB -= ( dE_dWB *= learning_rate ); | |
biasB -= ( dE_dbB *= learning_rate ); | |
} | |
// | |
// 結果 (学習フェーズとほぼ同じコードなのでコメントは省略) | |
// | |
auto error = 0.0; | |
for ( auto data : datas ) | |
{ | |
SimpleMatrix< double > input( 2, 1, { data.input.A, data.input.B } ); | |
auto hidden = weightA * input + biasA; | |
hidden( 0, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 0, 0 ) ) ); | |
hidden( 1, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 1, 0 ) ) ); | |
hidden( 2, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 2, 0 ) ) ); | |
auto output = weightB * hidden + biasB; | |
output( 0, 0 ) = 1.0 / ( 1.0 + std::exp( -output( 0, 0 ) ) ); | |
output( 1, 0 ) = 1.0 / ( 1.0 + std::exp( -output( 1, 0 ) ) ); | |
error += 0.5 * ( std::pow( output( 0, 0 ) - data.supervisor.S, 2.0 ) + std::pow( output( 1, 0 ) - data.supervisor.C, 2.0 ) ); | |
} | |
std::cout << training_times << " -- " << error << std::endl; | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SimpleMatrix.h is placed in PUBLIC DOMAIN. | |
// | |
#if !defined( SIMPLE_MATRIX_H ) | |
#define SIMPLE_MATRIX_H | |
#include <cstddef> | |
#include <cassert> | |
#include <vector> | |
#include <utility> | |
// シンプルな行列演算クラス | |
template< class NumberType > | |
class SimpleMatrix | |
{ | |
// | |
// メンバ関数 | |
// | |
public: | |
// コンストラクタ | |
explicit SimpleMatrix( const std::size_t rows, const std::size_t columns ); | |
// コンストラクタ (単一の値で初期化) | |
explicit SimpleMatrix( const std::size_t rows, const std::size_t columns, const NumberType initial_value ); | |
// コンストラクタ (初期化リストで初期化) | |
explicit SimpleMatrix( const std::size_t rows, const std::size_t columns, std::initializer_list< NumberType > initial_values ); | |
// コンストラクタ (ベクタで初期化) | |
explicit SimpleMatrix( const std::size_t rows, const std::size_t columns, const std::vector< NumberType > & initial_values ); | |
// コピーコンストラクタ | |
SimpleMatrix( const SimpleMatrix & ) = default; | |
// ムーブコンストラクタ | |
SimpleMatrix( SimpleMatrix && source ); | |
// デストラクタ | |
~SimpleMatrix() = default; | |
// 代入演算子 | |
SimpleMatrix & operator=( const SimpleMatrix & ) = default; | |
// ムーブ代入演算子 | |
SimpleMatrix & operator=( SimpleMatrix && source ); | |
// すべての要素を指定の値にする | |
SimpleMatrix & fill( const NumberType & value ); | |
// 行数を取得する | |
std::size_t get_number_of_rows() const; | |
// 列数を取得する | |
std::size_t get_number_of_columns() const; | |
// 要素参照 (non-const) | |
NumberType & operator()( const std::size_t row, const std::size_t column ); | |
// 要素参照 (const) | |
const NumberType & operator()( const std::size_t row, const std::size_t column ) const; | |
// 加算 | |
SimpleMatrix operator+( const SimpleMatrix & rhs ) const; | |
// 減算 | |
SimpleMatrix operator-( const SimpleMatrix & rhs ) const; | |
// 乗算 | |
SimpleMatrix operator*( const SimpleMatrix & rhs ) const; | |
// アダマール積 (要素ごとの積) | |
SimpleMatrix element_product( const SimpleMatrix & rhs ) const; | |
// 加算複合代入 | |
SimpleMatrix & operator+=( const SimpleMatrix & rhs ); | |
// 減算複合代入 | |
SimpleMatrix & operator-=( const SimpleMatrix & rhs ); | |
// 乗算複合代入 (対スカラー) | |
SimpleMatrix & operator*=( const NumberType & rhs ); | |
// 除算複合代入 (対スカラー) | |
SimpleMatrix & operator/=( const NumberType & rhs ); | |
// インプレースでアダマール積 (要素ごとの積) | |
SimpleMatrix & element_product_inplace( const SimpleMatrix & rhs ); | |
// 転置 | |
SimpleMatrix transpose() const; | |
// | |
// メンバ変数 | |
// | |
private: | |
// 行数 | |
std::size_t rows_; | |
// 列数 | |
std::size_t columns_; | |
// 保存領域 | |
std::vector< NumberType > buffer_; | |
}; | |
// コンストラクタ | |
template< class NumberType > | |
SimpleMatrix< NumberType >::SimpleMatrix( const std::size_t rows, const std::size_t columns ) | |
: rows_( rows ) | |
, columns_( columns ) | |
, buffer_( rows * columns ) | |
{ | |
// アサーション | |
assert( rows > 0 && columns > 0 ); | |
} | |
// コンストラクタ (単一の値で初期化) | |
template< class NumberType > | |
SimpleMatrix< NumberType >::SimpleMatrix( const std::size_t rows, const std::size_t columns, const NumberType initial_value ) | |
: rows_( rows ) | |
, columns_( columns ) | |
, buffer_( rows * columns, initial_value ) | |
{ | |
// アサーション | |
assert( rows > 0 && columns > 0 ); | |
} | |
// コンストラクタ (初期化リストで初期化) | |
template< class NumberType > | |
SimpleMatrix< NumberType >::SimpleMatrix( const std::size_t rows, const std::size_t columns, std::initializer_list< NumberType > initial_values ) | |
: rows_( rows ) | |
, columns_( columns ) | |
, buffer_( initial_values ) | |
{ | |
// アサーション | |
assert( rows > 0 && columns > 0 && initial_values.size() == rows * columns ); | |
} | |
// コンストラクタ (ベクタで初期化) | |
template< class NumberType > | |
SimpleMatrix< NumberType >::SimpleMatrix( const std::size_t rows, const std::size_t columns, const std::vector< NumberType > & initial_values ) | |
: rows_( rows ) | |
, columns_( columns ) | |
, buffer_( initial_values ) | |
{ | |
// アサーション | |
assert( rows > 0 && columns > 0 && initial_values.size() == rows * columns ); | |
} | |
// ムーブコンストラクタ | |
template< class NumberType > | |
SimpleMatrix< NumberType >::SimpleMatrix( SimpleMatrix && source ) | |
: rows_( source.rows_ ) | |
, columns_( source.columns_ ) | |
, buffer_( std::move( source.buffer_ ) ) | |
{ | |
// source.buffer_ の内容が保証されなくなったので rows_ と columns_ も更新する | |
source.rows_ = source.columns_ = 0; | |
} | |
// ムーブ代入演算子 | |
template< class NumberType > | |
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator=( SimpleMatrix && source ) | |
{ | |
rows_ = source.rows_; | |
columns_ = source.columns_; | |
buffer_ = std::move( source.buffer_ ); | |
// source.buffer_ の内容が保証されなくなったので rows_ と columns_ も更新する | |
source.rows_ = source.columns_ = 0; | |
return *this; | |
} | |
// すべての要素を指定の値にする | |
template< class NumberType > | |
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::fill( const NumberType & value ) | |
{ | |
std::vector< NumberType > temp( rows_ * columns_, value ); | |
buffer_.swap( temp ); | |
return *this; | |
// 一時的にメモリ消費量がほぼ倍になるので, | |
// 要素数が非常に大きい場合は std::fill を使うほうがよいと思う. | |
// by Atsushi OHTA (2016/7/2) | |
} | |
// 行数を取得する | |
template< class NumberType > | |
std::size_t SimpleMatrix< NumberType >::get_number_of_rows() const | |
{ | |
return rows_; | |
} | |
// 列数を取得する | |
template< class NumberType > | |
std::size_t SimpleMatrix< NumberType >::get_number_of_columns() const | |
{ | |
return columns_; | |
} | |
// 要素参照 (non-const) | |
template< class NumberType > | |
NumberType & SimpleMatrix< NumberType >::operator()( const std::size_t row, const std::size_t column ) | |
{ | |
// アサーション | |
assert( row < rows_ && column < columns_ ); | |
return buffer_[ row * columns_ + column ]; | |
} | |
// 要素参照 (const) | |
template< class NumberType > | |
const NumberType & SimpleMatrix< NumberType >::operator()( const std::size_t row, const std::size_t column ) const | |
{ | |
// アサーション | |
assert( row < rows_ && column < columns_ ); | |
return buffer_[ row * columns_ + column ]; | |
} | |
// 加算 | |
template< class NumberType > | |
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::operator+( const SimpleMatrix & rhs ) const | |
{ | |
SimpleMatrix result( *this ); | |
return result += rhs; | |
} | |
// 減算 | |
template< class NumberType > | |
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::operator-( const SimpleMatrix & rhs ) const | |
{ | |
SimpleMatrix result( *this ); | |
return result -= rhs; | |
} | |
// 乗算 | |
template< class NumberType > | |
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::operator*( const SimpleMatrix & rhs ) const | |
{ | |
// アサーション | |
assert( columns_ == rhs.rows_ ); | |
SimpleMatrix result( rows_, rhs.columns_ ); | |
for ( auto i = 0; i < rows_; ++i ) | |
{ | |
for ( auto j = 0; j < rhs.columns_; ++j ) | |
{ | |
NumberType temp = NumberType( 0 ); | |
for ( auto k = 0; k < columns_; ++k ) | |
{ | |
temp += ( *this )( i, k ) * rhs( k, j ); | |
} | |
result( i, j ) = temp; | |
} | |
} | |
return result; | |
} | |
// アダマール積 (要素ごとの積) | |
template< class NumberType > | |
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::element_product( const SimpleMatrix & rhs ) const | |
{ | |
SimpleMatrix result( *this ); | |
return result.element_product_inplace( rhs ); | |
} | |
// 加算複合代入 | |
template< class NumberType > | |
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator+=( const SimpleMatrix & rhs ) | |
{ | |
// アサーション | |
assert( rows_ == rhs.rows_ && columns_ == rhs.columns_ ); | |
auto it = rhs.buffer_.cbegin(); | |
for ( auto & elem : buffer_ ) | |
{ | |
elem += *it++; | |
} | |
return *this; | |
} | |
// 減算複合代入 | |
template< class NumberType > | |
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator-=( const SimpleMatrix & rhs ) | |
{ | |
// アサーション | |
assert( rows_ == rhs.rows_ && columns_ == rhs.columns_ ); | |
auto it = rhs.buffer_.cbegin(); | |
for ( auto & elem : buffer_ ) | |
{ | |
elem -= *it++; | |
} | |
return *this; | |
} | |
// 乗算複合代入 (対スカラー) | |
template< class NumberType > | |
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator*=( const NumberType & rhs ) | |
{ | |
for ( auto & elem : buffer_ ) | |
{ | |
elem *= rhs; | |
} | |
return *this; | |
} | |
// 除算複合代入 (対スカラー) | |
template< class NumberType > | |
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator/=( const NumberType & rhs ) | |
{ | |
for ( auto & elem : buffer_ ) | |
{ | |
elem /= rhs; | |
} | |
return *this; | |
} | |
// インプレースでアダマール積 (要素ごとの積) | |
template< class NumberType > | |
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::element_product_inplace( const SimpleMatrix & rhs ) | |
{ | |
// アサーション | |
assert( rows_ == rhs.rows_ && columns_ == rhs.columns_ ); | |
auto it = rhs.buffer_.cbegin(); | |
for ( auto & elem : buffer_ ) | |
{ | |
elem *= *it++; | |
} | |
return *this; | |
} | |
// 転置 | |
template< class NumberType > | |
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::transpose() const | |
{ | |
SimpleMatrix result( columns_, rows_ ); | |
for ( auto i = 0; i < columns_; ++i ) | |
{ | |
for ( auto j = 0; j < rows_; ++j ) | |
{ | |
result( i, j ) = ( *this )( j, i ); | |
} | |
} | |
return result; | |
} | |
#endif // !defined( SIMPLE_MATRIX_H ) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment