Last active
December 26, 2017 03:10
-
-
Save odashi/83bef0a8d24d293811bd7852ba1916ad to your computer and use it in GitHub Desktop.
primitiv examples for Qiita (C++11/Python3)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// How to run:
//   g++ -std=c++11 xor.cc -lprimitiv && ./a.out
#include <cstdio> | |
#include <iostream> | |
#include <primitiv/primitiv.h> | |
using namespace primitiv; | |
namespace D = primitiv::devices; | |
namespace F = primitiv::functions; | |
namespace I = primitiv::initializers; | |
namespace O = primitiv::optimizers; | |
int main() { | |
// デバイスと計算グラフの設定 | |
devices::Naive dev; | |
Device::set_default(dev); | |
Graph g; | |
Graph::set_default(g); | |
// 入力データ | |
std::vector<float> input_data { | |
1, 1, // 第一象限 | |
-1, 1, // 第二象限 | |
-1, -1, // 第三象限 | |
1, -1, // 第四象限 | |
}; | |
// 対応する正解 | |
std::vector<float> label_data { | |
1, // 第一象限 | |
-1, // 第二象限 | |
1, // 第三象限 | |
-1, // 第四象限 | |
}; | |
// パラメータ | |
const int N = 8; | |
Parameter pw({1, N}, I::XavierUniform()); | |
Parameter pb({}, I::Constant(0)); | |
Parameter pu({N, 2}, I::XavierUniform()); | |
Parameter pc({N}, I::Constant(0)); | |
// 学習器 | |
O::SGD optimizer(0.5); | |
optimizer.add(pw, pb, pu, pc); | |
// ネットワークの定義 | |
auto build_graph = [&] { | |
auto x = F::input<Node>(Shape({2}, 4), input_data); | |
auto w = F::parameter<Node>(pw); | |
auto b = F::parameter<Node>(pb); | |
auto u = F::parameter<Node>(pu); | |
auto c = F::parameter<Node>(pc); | |
auto h = F::tanh(F::matmul(u, x) + c); | |
return F::tanh(F::matmul(w, h) + b); | |
}; | |
// 損失の定義 | |
auto calc_loss = [&](Node y) { | |
auto t = F::input<Node>(Shape({}, 4), label_data); | |
auto diff = y - t; | |
return F::batch::mean(diff * diff); | |
}; | |
// 学習ループ | |
for (int epoch = 0; epoch < 20; ++epoch) { | |
std::cout << epoch << ' '; | |
// グラフの初期化 | |
g.clear(); | |
// 出力の計算 | |
auto y = build_graph(); | |
for (float val : y.to_vector()) { | |
std::printf("%+.6f, ", val); | |
} | |
// 損失の計算 | |
auto loss = calc_loss(y); | |
std::printf("loss=%.6f", loss.to_float()); | |
std::cout << std::endl; | |
// 勾配の計算・パラメータの更新 | |
optimizer.reset_gradients(); | |
loss.backward(); | |
optimizer.update(); | |
} | |
return 0; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# How to run: ./xor.py
import numpy as np | |
from primitiv import * | |
from primitiv import devices as D | |
from primitiv import functions as F | |
from primitiv import initializers as I | |
from primitiv import optimizers as O | |
# Register the device and the computation graph as primitiv's defaults.
dev = D.Naive()
Device.set_default(dev)
g = Graph()
Graph.set_default(g)

# Input data: one 2x1 column vector per quadrant, in the order
# 1st, 2nd, 3rd, 4th quadrant.
_QUADRANT_POINTS = ((1, 1), (-1, 1), (-1, -1), (1, -1))
input_data = [np.array([[px], [py]]) for px, py in _QUADRANT_POINTS]

# Expected output for each quadrant, in the same order as input_data.
_QUADRANT_LABELS = (1, -1, 1, -1)
label_data = [np.array([label]) for label in _QUADRANT_LABELS]

# Trainable parameters. N is the size shared by pw, pu and pc.
N = 8
pw = Parameter([1, N], I.XavierUniform())
pb = Parameter([], I.Constant(0))
pu = Parameter([N, 2], I.XavierUniform())
pc = Parameter([N], I.Constant(0))

# Optimizer: SGD constructed with 0.5 (presumably the learning rate),
# managing all four parameters.
optimizer = O.SGD(0.5)
optimizer.add(pw, pb, pu, pc)
def build_graph():
    """Define the network and return its output node.

    Evaluates tanh(w @ tanh(u @ x + c) + b), where x is created from
    ``input_data`` and w, b, u, c wrap the parameters ``pw``, ``pb``,
    ``pu`` and ``pc``.
    """
    inputs = F.input(input_data)
    out_weight = F.parameter(pw)
    out_bias = F.parameter(pb)
    hidden_weight = F.parameter(pu)
    hidden_bias = F.parameter(pc)
    hidden = F.tanh(hidden_weight @ inputs + hidden_bias)
    output = F.tanh(out_weight @ hidden + out_bias)
    return output
def calc_loss(y):
    """Return ``F.batch.mean`` of the squared difference between y and the labels.

    Args:
        y: Network output node, as returned by ``build_graph``.
    """
    target = F.input(label_data)
    error = y - target
    squared_error = error * error
    return F.batch.mean(squared_error)
# Training loop: 20 epochs, one optimizer update per epoch.
for epoch in range(20):
    print(epoch, end=' ')

    # Start every epoch from an empty graph.
    g.clear()

    # Compute the outputs and print each value.
    y = build_graph()
    outputs = y.to_list()
    for val in outputs:
        print('{:+.6f},'.format(val), end=' ')

    # Compute the loss and print it, ending the line.
    loss = calc_loss(y)
    loss_value = loss.to_float()
    print('loss={:.6f}'.format(loss_value))

    # Compute gradients and update the parameters.
    optimizer.reset_gradients()
    loss.backward()
    optimizer.update()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment