Introduction to Deep Learning with Python (Pythonでディープラーニング入門): sample code
The following lines go at the end of main() in train_mnist.py (the full script is further below) so that the trained parameters are written out after training:

    # Run the training
    trainer.run()

    # Additions start here
    # Save the trained model
    chainer.serializers.save_npz("trained_mnist.model", model)


if __name__ == '__main__':
    main()
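As a quick check that the file was actually written, note that save_npz produces a standard NumPy .npz archive, so its parameter arrays can be listed directly. A minimal sketch, assuming the snippet above has run and created trained_mnist.model:

import numpy as np

# trained_mnist.model is an .npz archive of named parameter arrays
with np.load("trained_mnist.model") as params:
    for name in params.files:
        print(name, params[name].shape)  # keys such as predictor/l1/W are expected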
predict_mnist.py loads trained_mnist.model and classifies a handwritten digit image (number.png):

import chainer
import chainer.functions as F
import chainer.links as L
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)    # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


# Rebuild the network and load the trained parameters
model = L.Classifier(MLP(1000, 10))
chainer.serializers.load_npz('trained_mnist.model', model)

# Load the handwritten-digit image as grayscale and display it
image = Image.open("number.png").convert('L')
plt.imshow(image, cmap='gray')
plt.title('input data')
plt.show()

# Normalize to [0, 1] and flatten into a (1, 784) mini-batch
image = np.asarray(image).astype(np.float32) / 255
image = image.reshape((1, -1))

# Forward pass: print the predicted digit and the raw score for each class
result = model.predictor(chainer.Variable(image))
print('predicted', ':', np.argmax(result.data))
for i in range(10):
    print(str(i), ":", str(result.data[0, i]))
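predict_mnist.py prints raw, unnormalized scores (logits) for each digit. If probabilities are more convenient, a softmax can be applied to the same output; a minimal sketch using chainer.functions.softmax, with result taken from the script above:

import chainer.functions as F

# Turn the raw scores into a probability distribution over the 10 digits
probs = F.softmax(result).data[0]
for i, p in enumerate(probs):
    print('{}: {:.4f}'.format(i, p))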
(base) C:\Users\Secollege160405\Desktop\mnist>python predict_mnist.py
predicted : 3 | |
0 : -17.519121 | |
1 : -6.611537 | |
2 : -3.5029051 | |
3 : 28.767323 | |
4 : -16.44517 | |
5 : 0.56008315 | |
6 : -28.957706 | |
7 : -5.5690756 | |
8 : -3.870232 | |
9 : -4.1418304 | |
(base) C:\Users\Secollege160405\Desktop\mnist>
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.189386    0.088241              0.943733       0.9717                    188.366
2           0.0725646   0.0875341             0.977683       0.9714                    378.107
3           0.049227    0.0849267             0.984167       0.9735                    564.709
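This table is printed by the PrintReport extension in train_mnist.py below; the LogReport extension also stores the same statistics as JSON in a file named log under the output directory (result by default). A minimal sketch for re-reading them after a run, assuming the default --out setting:

import json

# One dict per reporting interval (here, per epoch)
with open('result/log') as f:
    log = json.load(f)

for entry in log:
    print(entry['epoch'], entry['main/loss'],
          entry['validation/main/loss'], entry['validation/main/accuracy'])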
train_mnist.py, the full training script:

#!/usr/bin/env python
import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions


# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)    # n_units -> n_out

    def forward(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


# Definitions of the command-line options such as -g and -e used when running the script
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Create the model
    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(MLP(args.unit, 10))  # 10 is the number of output classes
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Create the iterators below
    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Create the trainer
    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if args.plot and extensions.PlotReport.available():
        trainer.extend(extensions.PlotReport(
            ['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png'))
        trainer.extend(extensions.PlotReport(
            ['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()


if __name__ == '__main__':
    main()
> python train_mnist.py -g -1 -e 3
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 3

C:\ProgramData\Anaconda3\lib\site-packages\chainer\optimizers\adam.py:111: RuntimeWarning: invalid value encountered in sqrt
  param.data -= hp.eta * (self.lr * m / (numpy.sqrt(vhat) + hp.eps) +
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
C:\ProgramData\Anaconda3\lib\site-packages\chainer\optimizers\adam.py:111: RuntimeWarning: invalid value encountered in sqrt
  param.data -= hp.eta * (self.lr * m / (numpy.sqrt(vhat) + hp.eps) +
1           0.189386    0.088241              0.943733       0.9717                    188.366
2           0.0725646   0.0875341             0.977683       0.9714                    378.107
3           0.049227    0.0849267             0.984167       0.9735                    564.709
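Once a run like the one above has finished and the parameters have been saved with the save_npz line from the first snippet, the stored model can be sanity-checked against the full MNIST test set. A minimal sketch, reusing the MLP class from train_mnist.py and assuming trained_mnist.model exists in the working directory:

import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L

# Rebuild the network and load the saved parameters
model = L.Classifier(MLP(1000, 10))
chainer.serializers.load_npz('trained_mnist.model', model)

# get_mnist() yields (image, label) pairs; stack them into plain arrays
_, test = chainer.datasets.get_mnist()
x = np.stack([img for img, _ in test]).astype(np.float32)    # (10000, 784)
t = np.array([label for _, label in test], dtype=np.int32)   # (10000,)

# Forward pass in test mode and report overall accuracy
with chainer.using_config('train', False):
    y = model.predictor(x)
print('test accuracy:', float(F.accuracy(y, t).data))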