Created September 30, 2016 06:15
-
-
Save macrat/8807b987bee11d079dd51710e1b89cae to your computer and use it in GitHub Desktop.
CNNで特定物体認識 (Specific object recognition with a CNN)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import chainer | |
import numpy | |
import sklearn.base | |
class CNNFinder(chainer.FunctionSet):
    """Small convolutional network with a two-way linear output head.

    Architecture: four stages of ``Convolution2D`` (kernel size 4) followed
    by ReLU and 2x2 max pooling, then ``Linear(320, 160)`` -> ReLU ->
    ``Linear(160, 2)``.

    ``l1`` expects 320 flattened features (80 channels x 2 x 2). That is what
    a 1x64x64 input yields with Chainer's default pooling, which is the input
    the ``__main__`` script feeds. Other input sizes will presumably fail at
    ``l1``.

    The network owns its optimizer (momentum SGD, lr=0.01, momentum=0.9),
    bound to its own parameters, so :meth:`train` updates it in place.
    """

    def __init__(self):
        layers = dict(
            conv1=chainer.functions.Convolution2D(1, 10, 4),
            conv2=chainer.functions.Convolution2D(10, 20, 4),
            conv3=chainer.functions.Convolution2D(20, 40, 4),
            conv4=chainer.functions.Convolution2D(40, 80, 4),
            l1=chainer.functions.Linear(320, 160),
            l2=chainer.functions.Linear(160, 2),
        )
        super().__init__(**layers)
        self.optimizer = chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9)
        self.optimizer.setup(self)

    def forward(self, x):
        """Return the raw (pre-softmax) two-class scores for ``x``."""
        h = x
        for layer_name in ('conv1', 'conv2', 'conv3', 'conv4'):
            activated = chainer.functions.relu(getattr(self, layer_name)(h))
            h = chainer.functions.max_pooling_2d(activated, 2)
        hidden = chainer.functions.relu(self.l1(h))
        return self.l2(hidden)

    def train(self, xs, ys):
        """Run one optimizer step per (sample, label) pair from ``xs``/``ys``.

        Each pair is wrapped in its own Variable. Presumably every element
        already carries a leading batch axis of size 1, so each step is a
        mini-batch of one. Pairs are formed with ``zip``, so surplus items in
        the longer sequence are ignored.
        """
        for sample, label in zip(xs, ys):
            scores = self.forward(chainer.Variable(sample, volatile=False))
            self.optimizer.zero_grads()
            loss = chainer.functions.softmax_cross_entropy(
                scores,
                chainer.Variable(label, volatile=False),
            )
            loss.backward()
            self.optimizer.update()
class SkCNNFinder(CNNFinder, sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin):
    """scikit-learn compatible wrapper around :class:`CNNFinder`.

    Every constructor argument is stored under its own name, so
    ``sklearn.base.clone`` (used by cross-validation) can rebuild a fresh,
    untrained network per fold.

    Parameters
    ----------
    loop_num : int
        Number of passes (epochs) over the training data in :meth:`fit`.
    verbose : bool
        Print a progress line after each epoch.
    shuffle : bool
        Visit the training samples in a new random order every epoch.
        Training is per-sample SGD with momentum. Feeding samples in their
        stored order (e.g. all positives followed by all negatives) would let
        the tail of every epoch pull the weights toward a single class.
    random_state : int, numpy.random.RandomState or None
        Source of randomness for shuffling. ``None`` seeds from the OS.
    """

    def __init__(self, loop_num=100, verbose=False, shuffle=True,
                 random_state=None):
        super().__init__()
        self.loop_num = loop_num
        self.verbose = verbose
        self.shuffle = shuffle
        self.random_state = random_state

    def fit(self, X, y):
        """Train on samples ``X`` with integer class labels ``y``.

        ``X`` must support integer indexing (list or numpy array). Each
        element is passed to :meth:`CNNFinder.train` as-is.

        Returns ``self``, per scikit-learn convention.
        """
        # Chainer's softmax_cross_entropy requires int32 labels. Each label
        # gets its own length-1 axis so it pairs with a single-sample batch.
        labels = numpy.asarray(y, dtype=numpy.int32).reshape(-1, 1)
        if isinstance(self.random_state, numpy.random.RandomState):
            rng = self.random_state
        else:
            rng = numpy.random.RandomState(self.random_state)
        order = numpy.arange(len(labels))
        for i in range(self.loop_num):
            if self.shuffle:
                rng.shuffle(order)
            self.train([X[j] for j in order], labels[order])
            if self.verbose:
                print('{0}/{1} ({2:.1%})'.format(i + 1,
                                                 self.loop_num,
                                                 (i + 1) / self.loop_num))
        return self

    def predict(self, xs):
        """Return the predicted class index (0 or 1) for each sample in ``xs``."""
        # volatile=True: inference only, so Chainer need not keep the
        # computational graph for backpropagation.
        return numpy.array([
            self.forward(chainer.Variable(x, volatile=True)).data.argmax()
            for x in xs
        ])
if __name__ == '__main__':
    # Script entry point: load positive examples from "disks/" and negative
    # examples from "ng/", then report 5-fold cross-validated accuracy.
    import os

    import cv2

    try:
        # scikit-learn >= 0.18; sklearn.cross_validation was removed in 0.20.
        from sklearn.model_selection import cross_val_score
    except ImportError:
        from sklearn.cross_validation import cross_val_score

    def pre_process(img):
        """Turn a BGR image into a (1, 1, 64, 64) float32 array scaled to [0, 1]."""
        img = cv2.cvtColor(cv2.resize(img, (64, 64)), cv2.COLOR_BGR2GRAY)
        return numpy.array(img.reshape((1, 1, 64, 64)) / 255, numpy.float32)

    def load_dir(directory):
        """Load and pre-process every decodable image in ``directory``.

        Entries OpenCV cannot decode (``cv2.imread`` returns None, e.g. stray
        non-image files or subdirectories) are skipped with a message instead
        of crashing inside ``cv2.resize``. Names are sorted so the load order
        is reproducible across platforms.
        """
        images = []
        for name in sorted(os.listdir(directory)):
            path = os.path.join(directory, name)
            img = cv2.imread(path)
            if img is None:
                print('skip unreadable file: ' + path)
                continue
            images.append(pre_process(img))
        return images

    print('load image')
    corrects = load_dir('disks/')
    wrongs = load_dir('ng/')
    if not corrects or not wrongs:
        raise SystemExit('need at least one readable image in both disks/ and ng/')

    print('concat')
    data = numpy.array(corrects + wrongs)
    answer = numpy.array([1] * len(corrects) + [0] * len(wrongs), numpy.int32)
    # Free the per-image lists; ``data`` holds a copy of everything needed.
    del corrects
    del wrongs

    print('starting learn')
    score = cross_val_score(
        SkCNNFinder(verbose=True),
        data,
        answer,
        cv=5,
        verbose=3,
    )
    print()
    print()
    # Mean accuracy +/- two standard deviations, both shown as percentages.
    print('{0:.2%} +/-{1:.2%}'.format(score.mean(), score.std() * 2))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment