Created
May 18, 2016 04:11
-
-
Save xlvector/6923ef145e59de44ed06f21228f2f879 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys | |
sys.path.insert(0, "../../python") | |
import mxnet as mx | |
import numpy as np | |
import cv2, random | |
from io import BytesIO | |
from captcha.image import ImageCaptcha | |
class OCRBatch(object):
    """One mini-batch of arrays together with the names they are bound to.

    ``data`` and ``label`` are sequences of array-like objects exposing a
    ``.shape`` attribute; ``data_names`` and ``label_names`` are the matching
    sequences of names, paired positionally with the arrays.
    """

    def __init__(self, data_names, data, label_names, label):
        self.data_names = data_names
        self.label_names = label_names
        self.data = data
        self.label = label

    @staticmethod
    def _describe(names, arrays):
        """Pair every name with the shape of the array at the same position."""
        described = []
        for name, array in zip(names, arrays):
            described.append((name, array.shape))
        return described

    @property
    def provide_data(self):
        """List of ``(name, shape)`` tuples, one per data array."""
        return self._describe(self.data_names, self.data)

    @property
    def provide_label(self):
        """List of ``(name, shape)`` tuples, one per label array."""
        return self._describe(self.label_names, self.label)
def gen_rand():
    """Return a random string of exactly four decimal digits.

    An integer is drawn uniformly from 0..9999 (inclusive) and left-padded
    with zeros, so results range from "0000" to "9999".
    """
    return str(random.randint(0, 9999)).zfill(4)
def get_label(buf):
    """Convert a string of digit characters into a NumPy array of integers.

    Each character of ``buf`` is parsed with ``int`` and becomes one element
    of the returned array, in the same order.
    """
    digits = list(map(int, buf))
    return np.array(digits)
class OCRIter(mx.io.DataIter):
    """Data iterator that synthesizes captcha images and their labels on the fly.

    Every pass over the iterator yields ``count // batch_size`` batches. Each
    batch holds ``batch_size`` freshly generated captcha images (resized to
    ``width`` x ``height``, scaled by 1/255 and transposed to channel-first
    order) together with the digits that were rendered into them.

    Parameters
    ----------
    count : int
        Number of samples per pass; a trailing partial batch is dropped.
    batch_size : int
        Number of samples in every yielded batch.
    num_label : int
        Label width advertised in ``provide_label``. The labels themselves
        come from ``gen_rand()`` / ``get_label()``.
    height, width : int
        Size the generated captcha image is resized to.
    """

    def __init__(self, count, batch_size, num_label, height, width):
        super(OCRIter, self).__init__()
        self.captcha = ImageCaptcha(fonts=['./data/OpenSans-Regular.ttf'])
        self.batch_size = batch_size
        self.count = count
        self.height = height
        self.width = width
        self.provide_data = [('data', (batch_size, 3, height, width))]
        self.provide_label = [('softmax_label', (self.batch_size, num_label))]

    def __iter__(self):
        # Floor division keeps the batch count an int on Python 3 as well;
        # plain "/" yields a float there and range() rejects it.
        for _ in range(self.count // self.batch_size):
            data = []
            label = []
            for i in range(self.batch_size):
                num = gen_rand()
                encoded = self.captcha.generate(num)
                # frombuffer is the supported way to view raw bytes as uint8;
                # np.fromstring is deprecated for binary input.
                raw = np.frombuffer(encoded.getvalue(), dtype='uint8')
                img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
                img = cv2.resize(img, (self.width, self.height))
                # Dump the resized image to disk; the file name cycles on
                # i % 10, so at most ten files exist at any time.
                cv2.imwrite("./tmp" + str(i % 10) + ".png", img)
                img = np.multiply(img, 1 / 255.0)
                # Reorder axes from (rows, cols, channels) to (channels, rows, cols).
                img = img.transpose(2, 0, 1)
                data.append(img)
                label.append(get_label(num))

            data_all = [mx.nd.array(data)]
            label_all = [mx.nd.array(label)]
            data_names = ['data']
            label_names = ['softmax_label']

            yield OCRBatch(data_names, data_all, label_names, label_all)

    def reset(self):
        # Nothing to rewind: every call to __iter__ generates new samples.
        pass
def get_ocrnet():
    """Build the captcha recognition network symbol.

    Three convolution -> pooling -> ReLU stages (32 filters each) feed a
    512-unit fully connected layer. Four independent 10-way fully connected
    heads sit on top of it and are concatenated along axis 0. The label
    variable is transposed and flattened before being passed, together with
    the concatenated heads, to a single ``SoftmaxOutput`` named "softmax".
    """
    data = mx.symbol.Variable('data')
    label = mx.symbol.Variable('softmax_label')

    # (convolution kernel, pooling type) for each feature-extraction stage.
    stages = (
        ((5, 5), "max"),
        ((5, 5), "avg"),
        ((3, 3), "avg"),
    )
    features = data
    for kernel, pool_type in stages:
        features = mx.symbol.Convolution(data=features, kernel=kernel, num_filter=32)
        features = mx.symbol.Pooling(data=features, pool_type=pool_type,
                                     kernel=(2, 2), stride=(1, 1))
        features = mx.symbol.Activation(data=features, act_type="relu")

    hidden = mx.symbol.FullyConnected(data=mx.symbol.Flatten(data=features),
                                      num_hidden=512)

    # One 10-way head per output slot, all reading the same hidden layer.
    heads = [mx.symbol.FullyConnected(data=hidden, num_hidden=10) for _ in range(4)]
    stacked = mx.symbol.Concat(*heads, dim=0)

    flat_label = mx.symbol.transpose(data=label)
    flat_label = mx.symbol.Reshape(data=flat_label, target_shape=(0, ))
    return mx.symbol.SoftmaxOutput(data=stacked, label=flat_label, name="softmax")
def Accuracy(label, pred):
    """Return the fraction of samples whose four characters are all predicted correctly.

    Parameters
    ----------
    label : numpy.ndarray
        Array of shape ``(num_samples, 4)`` holding the true class of every
        character of every sample.
    pred : numpy.ndarray
        Array of shape ``(4 * num_samples, num_classes)`` of class scores,
        laid out position-major: the first ``num_samples`` rows belong to
        character position 0, the next ``num_samples`` rows to position 1,
        and so on. This is the layout produced by concatenating the four
        per-position heads along axis 0.

    Returns
    -------
    float
        ``hits / num_samples``; ``0.0`` when ``pred`` holds fewer than four rows.
    """
    num_chars = 4
    # Floor division: "/" yields a float on Python 3.
    num_samples = pred.shape[0] // num_chars
    if num_samples == 0:
        # Avoid dividing by zero on an empty batch.
        return 0.0

    used = num_samples * num_chars
    # Transposing before flattening puts the labels into the same
    # position-major order as the rows of pred.
    expected = np.asarray(label).T.reshape((-1, ))[:used].astype(int)
    predicted = np.argmax(pred[:used], axis=1)

    # Row p of the reshaped matrix is character position p, column s is
    # sample s, so a sample is a hit only when its whole column matches.
    # Grouping four consecutive rows instead would mix different samples.
    matches = (predicted == expected).reshape((num_chars, num_samples))
    hit = int(np.all(matches, axis=0).sum())
    return 1.0 * hit / num_samples
import logging

# Symbol graph and the device list it is trained on.
network = get_ocrnet()
devs = [mx.gpu(0)]

# Training wrapper: 15 epochs with momentum 0.9 and a small weight decay.
model = mx.model.FeedForward(
    ctx=devs,
    symbol=network,
    num_epoch=15,
    learning_rate=0.001,
    wd=0.00001,
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
    momentum=0.9,
)

# Synthetic captcha streams: 100000 training and 1000 evaluation samples per
# pass, batches of 50, four labels per sample, images 30 high by 80 wide.
data_train = OCRIter(100000, 50, 4, 30, 80)
data_test = OCRIter(1000, 50, 4, 30, 80)

# Timestamped DEBUG-level logging for the training run.
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

model.fit(
    X=data_train,
    eval_data=data_test,
    eval_metric=Accuracy,
    batch_end_callback=mx.callback.Speedometer(32, 50),
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
can you show the pic of your network?