Skip to content

Instantly share code, notes, and snippets.

@gavinHuang
Created June 13, 2018 05:55
Show Gist options
  • Save gavinHuang/25bc752b0a05bff9a5bd871ace299ce4 to your computer and use it in GitHub Desktop.
Save gavinHuang/25bc752b0a05bff9a5bd871ace299ce4 to your computer and use it in GitHub Desktop.
CNN model for character classification (notMNIST letters A-J)
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import tensorflow as tf
from PIL import Image
import cv2
import os
import math
import sys
# Show INFO-level TensorFlow logs so that training progress and the
# LoggingTensorHook output configured in main() are printed.
tf.logging.set_verbosity(tf.logging.INFO)
def model_fn(features, labels, mode):
    """Build the character-classification CNN for tf.estimator.Estimator.

    features["x"] is reshaped to [-1, 28, 28, 1] and passed through
    conv(32, 5x5, ReLU) -> max-pool 2x2 -> conv(64, 5x5, ReLU) -> max-pool 2x2
    -> dense(1024, ReLU) -> dropout(rate 0.4, active only in TRAIN) -> dense(10).

    Returns a tf.estimator.EstimatorSpec:
      * PREDICT: predictions {"classes": argmax id, "probabilities": softmax}.
      * TRAIN:   softmax cross-entropy loss minimised by plain gradient
                 descent (learning rate 0.001, optimizer name "gd").
      * EVAL:    the same loss plus an "accuracy" metric.
    The softmax op is named "s_tensor" so a LoggingTensorHook can fetch it.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN

    net = tf.reshape(features['x'], [-1, 28, 28, 1])
    # Two conv + pool stages; each 2x2/stride-2 pool halves the spatial size
    # (28 -> 14 -> 7), which is why the flatten below uses 7 * 7 * 64.
    for num_filters in (32, 64):
        net = tf.layers.conv2d(
            inputs=net,
            filters=num_filters,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu)
        net = tf.layers.max_pooling2d(inputs=net, pool_size=[2, 2], strides=2)

    flattened = tf.reshape(net, [-1, 7 * 7 * 64])
    hidden = tf.layers.dense(inputs=flattened, units=1024, activation=tf.nn.relu)
    hidden = tf.layers.dropout(inputs=hidden, rate=0.4, training=training)
    logits = tf.layers.dense(inputs=hidden, units=10)

    class_ids = tf.argmax(input=logits, axis=1)
    class_probs = tf.nn.softmax(logits, name="s_tensor")
    predictions = {"classes": class_ids, "probabilities": class_probs}

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # From here on labels are required (TRAIN and EVAL).
    loss = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.one_hot(indices=tf.cast(labels, dtype=tf.int32), depth=10),
        logits=logits)

    if training:
        sgd = tf.train.GradientDescentOptimizer(learning_rate=0.001, name="gd")
        step_op = sgd.minimize(loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=step_op)

    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops={"accuracy": accuracy})
def _load_letter_images(letter_dir):
    """Read every image in letter_dir as a flattened float32 row scaled to [0, 1].

    Files PIL cannot open/decode and images that are not 28x28 are skipped
    with a warning instead of aborting the whole run.

    Returns a 2-D float32 array of shape (num_images, 784); it has zero rows
    when nothing in the directory was usable.
    """
    rows = []
    for file_name in sorted(os.listdir(letter_dir)):
        path = os.path.join(letter_dir, file_name)
        try:
            # Context manager guarantees the file handle is closed even when
            # decoding fails part-way.
            with Image.open(path) as img:
                img.load()
                pixels = np.asarray(img, dtype="float32")
        except (IOError, OSError) as err:  # IOError is distinct from OSError on Python 2
            tf.logging.warning("Skipping unreadable image %s: %s", path, err)
            continue
        if pixels.shape != (28, 28):
            # model_fn reshapes every example to 28x28x1.
            tf.logging.warning("Skipping %s: expected 28x28 pixels, got %s",
                               path, pixels.shape)
            continue
        # Scale 0..255 intensities into [0, 1]; raw 0..255 inputs give very large
        # activations under plain gradient descent.
        rows.append(pixels.ravel() / 255.0)
    if not rows:
        return np.empty((0, 28 * 28), dtype=np.float32)
    # Stack once at the end instead of np.vstack per image (which is O(n^2)).
    return np.stack(rows)


def _load_notmnist_split(file_dir, test_fraction=0.2):
    """Load single-letter class folders under file_dir and split each class.

    Folder "A" gets label 0, "B" label 1, ... (label = ord(name) - 65).
    Within each class the images are shuffled and the first test_fraction of
    them go to the test set, the rest to the training set.

    Returns (train_data, train_labels, test_data, test_labels): float32
    arrays of shape (N, 784) and int64 label vectors.

    Raises ValueError if no single-letter sub-directory is found.
    """
    train_parts, train_label_parts = [], []
    test_parts, test_label_parts = [], []
    for name in sorted(os.listdir(file_dir)):
        letter_dir = os.path.join(file_dir, name)
        if len(name) != 1 or not os.path.isdir(letter_dir):
            continue
        images = _load_letter_images(letter_dir)
        np.random.shuffle(images)  # shuffles rows (images) of the 2-D array
        # int() rather than math.floor(): floor returns a float on Python 2,
        # which cannot be used as a slice index.
        test_count = int(test_fraction * len(images))
        label = ord(name) - 65  # 65 == ord('A')
        test_parts.append(images[:test_count])
        train_parts.append(images[test_count:])
        test_label_parts.append(np.full(test_count, label, dtype=np.int64))
        train_label_parts.append(
            np.full(len(images) - test_count, label, dtype=np.int64))
    if not train_parts:
        raise ValueError("No single-letter image folders found in %s" % file_dir)
    return (np.concatenate(train_parts), np.concatenate(train_label_parts),
            np.concatenate(test_parts), np.concatenate(test_label_parts))


def main(unused_array):
    """Train the CNN on notMNIST_small, then print its test-set evaluation.

    unused_array: argv passed by tf.app.run(); not used.
    """
    file_dir = "/home/gavin/notMNIST/notMNIST_small"
    train_data, train_labels, test_data, test_labels = _load_notmnist_split(file_dir)
    tf.logging.info("Loaded %d training and %d test images",
                    len(train_data), len(test_data))

    # Log the softmax output (op named "s_tensor" in model_fn) every 50 steps.
    tensors_to_log = {"probabilities": "s_tensor"}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50)
    character_estimator = tf.estimator.Estimator(
        model_fn=model_fn, model_dir="/tmp/practice")

    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": train_data},
        y=train_labels,
        batch_size=50,
        num_epochs=None,
        shuffle=True)
    character_estimator.train(input_fn=train_input_fn, steps=20000, hooks=[logging_hook])

    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": test_data},
        y=test_labels,
        num_epochs=1,
        shuffle=False)
    eval_results = character_estimator.evaluate(input_fn=eval_input_fn)
    print(eval_results)
if __name__ == "__main__":
    # tf.app.run parses known command-line flags, then calls main(argv) and
    # exits with its return value; main is passed explicitly here.
    tf.app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment