Created
October 17, 2016 10:10
-
-
Save kaityo256/6d7e13c6b75976ad20e8ea10fc63bd4e to your computer and use it in GitHub Desktop.
Chainerで学習したモデルをC++で読み込む ref: http://qiita.com/kaityo256/items/f1e2c8e38cbf8ffd8c09
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import print_function | |
| import struct | |
| import numpy as np | |
| import chainer | |
| import chainer.functions as F | |
| import chainer.links as L | |
| from chainer import training | |
| from chainer.training import extensions | |
# Network definition
class MLP(chainer.Chain):
    """Two-layer perceptron: a ReLU hidden layer followed by a linear output.

    Both links are created with ``None`` as their input size.
    """

    def __init__(self, n_units, n_out):
        """Register the two ``L.Linear`` links.

        Args:
            n_units: Output size of the hidden link ``l1``.
            n_out: Output size of the final link ``l2``.
        """
        super(MLP, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(None, n_out),
        )

    def __call__(self, x):
        """Return ``l2(relu(l1(x)))`` for the input ``x``."""
        hidden = F.relu(self.l1(x))
        return self.l2(hidden)
# Rebuild the network with 3 hidden units and 2 outputs, then restore the
# trained parameters from the NPZ snapshot.
unit = 3
model = L.Classifier(MLP(unit, 2))
chainer.serializers.load_npz('and.model', model)

# Serialise the parameters as native floats in the order
# l1.W, l1.b, l2.W, l2.b.  ravel() walks each array in row-major order, the
# same order the original reshape() produced, without hard-coding layer sizes.
d = bytearray()
for param in (model.predictor.l1.W, model.predictor.l1.b,
              model.predictor.l2.W, model.predictor.l2.b):
    for v in param.data.ravel():
        d += struct.pack('f', v)

# The payload is binary: text mode rejects bytes on Python 3 and may translate
# newlines on Windows.  The with-block also closes the file deterministically.
with open("and.dat", 'wb') as f:
    f.write(d)

# Results
x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
y = model.predictor(x).data
for xi, yi in zip(x, y):
    print(xi, np.argmax(yi), yi)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| epoch main/loss validation/main/loss main/accuracy validation/main/accuracy | |
| 1 0.659301 0.609888 0.6771 0.75 | |
| 2 0.573191 0.535238 0.9976 1 | |
| 3 0.495501 0.453446 1 1 | |
| 4 0.413051 0.372437 1 1 | |
| 5 0.335372 0.299022 1 1 | |
| 6 0.267259 0.236726 1 1 | |
| 7 0.210785 0.186187 1 1 | |
| 8 0.165747 0.146453 1 1 | |
| 9 0.130642 0.115821 1 1 | |
| 10 0.103705 0.0923906 1 1 | |
| 11 0.0831349 0.0744629 1 1 | |
| 12 0.0673614 0.0607088 1 1 | |
| 13 0.0551783 0.050024 1 1 | |
| 14 0.0456828 0.0416145 1 1 | |
| 15 0.0381965 0.0349596 1 1 | |
| 16 0.0322445 0.0296493 1 1 | |
| 17 0.0274508 0.0253534 1 1 | |
| 18 0.0235528 0.0218315 1 1 | |
| 19 0.020354 0.0189316 1 1 | |
| 20 0.0177069 0.0165223 1 1 | |
| [ 0. 0.] 0 [ 4.44186878 -2.4494648 ] | |
| [ 0. 1.] 0 [ 3.35240507 -0.62914503] | |
| [ 1. 0.] 0 [ 2.42167044 -1.39461792] | |
| [ 1. 1.] 1 [-0.70726597 2.97649908] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Rectified linear unit: returns x when it is positive and 0 otherwise.
float relu(float x) {
  if (x > 0) {
    return x;
  }
  return 0;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| $ g++ import.cpp | |
| $ ./a.out | |
| [0.000000 0.000000] 0: [4.441869 -2.449465] | |
| [0.000000 1.000000] 0: [3.352405 -0.629145] | |
| [1.000000 0.000000] 0: [2.421670 -1.394618] | |
| [1.000000 1.000000] 1: [-0.707266 2.976499] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| $ zipinfo and.model | |
| Archive: and.model | |
| Zip file size: 815 bytes, number of entries: 4 | |
| -rw------- 2.0 unx 104 b- defN 16-Oct-17 18:33 predictor/l2/W.npy | |
| -rw------- 2.0 unx 92 b- defN 16-Oct-17 18:33 predictor/l1/b.npy | |
| -rw------- 2.0 unx 88 b- defN 16-Oct-17 18:33 predictor/l2/b.npy | |
| -rw------- 2.0 unx 104 b- defN 16-Oct-17 18:33 predictor/l1/W.npy | |
| 4 files, 388 bytes uncompressed, 345 bytes compressed: 11.1% |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Recreate the network (3 hidden units, 2 outputs) and restore the trained
# parameters from the saved snapshot.
unit = 3
model = L.Classifier(MLP(n_units=unit, n_out=2))
chainer.serializers.load_npz('and.model', model)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Append every weight of the first layer, flattened, as a native C float.
d = bytearray()
l1_weights = model.predictor.l1.W.data.reshape(2*unit)
d.extend(b''.join(struct.pack('f', w) for w in l1_weights))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The payload is binary: text mode rejects bytes on Python 3 and may translate
# newlines on Windows.  The with-block also closes the file deterministically.
with open("and.dat", 'wb') as f:
    f.write(d)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| $ python export.py | |
| [ 0. 0.] 0 [ 4.44186878 -2.4494648 ] | |
| [ 0. 1.] 0 [ 3.35240507 -0.62914503] | |
| [ 1. 0.] 0 [ 2.42167044 -1.39461792] | |
| [ 1. 1.] 1 [-0.70726597 2.97649908] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| y = W x + b |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| f(x) := \max(x,0) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdio>
#include <fstream>
#include <iostream>
#include <vector>
typedef std::vector<float> vf;

// One fully connected layer computing y = W x + b.
// W is stored row-major (n_out rows of n_in values), which is the order
// read() loads it in and get() indexes it with.
class Link {
private:
  vf W;  // n_in * n_out weights, row-major
  vf b;  // n_out biases
  const int n_in, n_out;

  // Rectified linear unit: positive values pass through, the rest become 0.
  // Static because it uses no instance state.
  static float relu(float x) {
    return (x > 0) ? x : 0;
  }

public:
  // Allocates zero-initialised storage for a layer with `in` inputs and
  // `out` outputs.
  Link(int in, int out) : n_in(in), n_out(out) {
    W.resize(n_in * n_out);
    b.resize(n_out);
  }

  // Reads n_in * n_out weights followed by n_out biases as raw floats.
  // A short or failed read leaves `ifs` in a failed state, which the caller
  // can test afterwards.
  void read(std::ifstream &ifs) {
    ifs.read(reinterpret_cast<char *>(W.data()), sizeof(float) * W.size());
    ifs.read(reinterpret_cast<char *>(b.data()), sizeof(float) * b.size());
  }

  // Returns W x + b. `x` must hold at least n_in values; it is taken by
  // const reference so the input is not copied on every call.
  vf get(const vf &x) const {
    vf y(n_out);
    for (int i = 0; i < n_out; i++) {
      // Accumulate the products first and add the bias last, keeping the
      // original floating-point summation order.
      float acc = 0.0f;
      for (int j = 0; j < n_in; j++) {
        acc += W[i * n_in + j] * x[j];
      }
      acc += b[i];
      y[i] = acc;
    }
    return y;
  }

  // Returns relu(W x + b), applied element-wise.
  vf get_relu(const vf &x) const {
    vf y = get(x);
    for (int i = 0; i < n_out; i++) {
      y[i] = relu(y[i]);
    }
    return y;
  }
};
// Returns the index of the largest element of v (the first one on ties),
// or -1 when v is empty. Taking a const reference lets callers pass const
// vectors and temporaries as well as plain lvalues.
int
argmax(const std::vector<float> &v) {
  if (v.empty()) {
    return -1;  // nothing to compare; avoids reading v[0] out of bounds
  }
  std::vector<float>::size_type best = 0;
  for (std::vector<float>::size_type i = 1; i < v.size(); i++) {
    // Strict comparison keeps the earliest index when values are equal.
    if (v[best] < v[i]) {
      best = i;
    }
  }
  return static_cast<int>(best);
}
| int | |
| main(void) { | |
| const int n_in = 2; | |
| const int n_units = 3; | |
| const int n_out = 2; | |
| std::ifstream ifs("and.dat"); | |
| Link l1(n_in, n_units), l2(n_units, n_out); | |
| l1.read(ifs); | |
| l2.read(ifs); | |
| float x[4][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}}; | |
| for (int i = 0; i < 4; i++) { | |
| vf x2; | |
| x2.push_back(x[i][0]); | |
| x2.push_back(x[i][1]); | |
| vf y = l2.get(l1.get_relu(x2)); | |
| printf("[%f %f] %d: [%f %f]\n", x2[0], x2[1], argmax(y), y[0], y[1]); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import print_function | |
| import numpy as np | |
| import chainer | |
| import chainer.functions as F | |
| import chainer.links as L | |
| from chainer import training | |
| from chainer.training import extensions | |
# Network definition
class MLP(chainer.Chain):
    """Two-layer perceptron: a ReLU hidden layer followed by a linear output.

    Both links are created with ``None`` as their input size.
    """

    def __init__(self, n_units, n_out):
        """Register the two ``L.Linear`` links.

        Args:
            n_units: Output size of the hidden link ``l1``.
            n_out: Output size of the final link ``l2``.
        """
        super(MLP, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(None, n_out),
        )

    def __call__(self, x):
        """Return ``l2(relu(l1(x)))`` for the input ``x``."""
        hidden = F.relu(self.l1(x))
        return self.l2(hidden)
# Data Preparation
def make_data(N):
    """Build ``N`` samples of the two-input AND truth table.

    Sample ``i`` has inputs ``(i % 2, (i // 2) % 2)`` and their bitwise AND
    as its label.

    Args:
        N: Number of samples to generate.

    Returns:
        A ``chainer.datasets.TupleDataset`` pairing a float32 ``(N, 2)``
        input array with an int32 ``(N,)`` label array.
    """
    idx = np.arange(N)
    x1 = idx % 2
    # Floor division keeps the second bit an integer on Python 3 as well;
    # true division would yield 0.5 steps and break the bitwise AND below.
    x2 = (idx // 2) % 2
    x = np.empty((N, 2), dtype=np.float32)
    x[:, 0] = x1
    x[:, 1] = x2
    y = (x1 & x2).astype(np.int32)
    return chainer.datasets.TupleDataset(x, y)
def main():
    """Train the classifier, save it to 'and.model' and print its outputs.

    Trains for 20 epochs with Adam on the samples from ``make_data``, then
    prints the input, predicted class and raw output for the four two-bit
    inputs.
    """
    n_epoch = 20
    batch = 100
    n_hidden = 3

    model = L.Classifier(MLP(n_hidden, 2))
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    test = make_data(100)
    train = make_data(10000)
    train_iter = chainer.iterators.SerialIterator(train, batch)
    test_iter = chainer.iterators.SerialIterator(
        test, batch, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (n_epoch, 'epoch'), out='result')

    # Extensions are registered in the same order as before.
    report_columns = [
        'epoch', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy',
    ]
    trainer.extend(extensions.Evaluator(test_iter, model))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=(n_epoch, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(report_columns))
    trainer.extend(extensions.ProgressBar())

    # Training
    trainer.run()
    chainer.serializers.save_npz('and.model', model)

    # Results
    inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
    outputs = model.predictor(inputs).data
    for row, out in zip(inputs, outputs):
        print(row, np.argmax(out), out)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment