Last active: July 20, 2019 02:40
-
-
Save Quorafind/a0d07b700b2fa2e91e487c074f45cc2d to your computer and use it in GitHub Desktop.
加载.pb文件并检测 (Load a .pb model file and run detection)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
from PIL import Image | |
# Placeholder settings: replace these values before running the script.
# They are string literals (not bare names) so the module can be imported.
FILEPATH = "path_to_img"  # path of the image to classify
MODELPATH = "path_to_pb_file"  # path of the .pb model file
SIZE = [100, 100]  # image size passed to get_check_data
DIC = {0: 'aaa', 1: 'bbb'}  # maps predicted class index -> class name
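# Convert the image into an array (图片转化为数组). NOTE(review): the body of get_check_data appears to be missing from this copy.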
def get_check_data(path, size): | |
# Detection: classify one image with a serialized TensorFlow graph.
def check(filename, modelname, size, input_name="input:0",
          keep_prob_name="keep_prob:0", output_name="fc3:0", labels=None):
    """Classify the image at ``filename`` with the graph stored in ``modelname``.

    Prints the predicted class index and its name, then returns the index.

    Args:
        filename: Path of the image; passed with ``size`` to
            ``get_check_data`` to build the model input.
        modelname: Path of the serialized ``GraphDef`` (``.pb``) file.
        size: Image size forwarded to ``get_check_data``.
        input_name: Name of the graph's input tensor.
        keep_prob_name: Name of the keep-probability placeholder; it is fed
            ``1.0``.
        output_name: Name of the tensor whose softmax gives class scores.
        labels: Mapping from class index to class name; defaults to the
            module-level ``DIC``.

    Returns:
        The predicted class index as a Python ``int``.

    Raises:
        KeyError: If the predicted index has no entry in ``labels``.
    """
    if labels is None:
        labels = DIC
    data = get_check_data(filename, size)

    # TF 2.x only exposes GraphDef/Session/import_graph_def under tf.compat.v1;
    # fall back to the top-level API on older 1.x releases.
    if hasattr(tf, "compat") and hasattr(tf.compat, "v1"):
        tf1 = tf.compat.v1
    else:
        tf1 = tf

    graph_def = tf1.GraphDef()
    with open(modelname, "rb") as f:  # read the .pb file
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        tf1.import_graph_def(graph_def, name="")
        input_x = graph.get_tensor_by_name(input_name)
        keep_prob = graph.get_tensor_by_name(keep_prob_name)
        out_y = graph.get_tensor_by_name(output_name)
        softmax = tf.nn.softmax(out_y)

    # Importing a bare GraphDef brings in no variable collections, so no
    # initializer is needed. The `with` closes the session even if run() raises.
    with tf1.Session(graph=graph) as sess:
        y = sess.run(softmax, feed_dict={input_x: data, keep_prob: 1.})

    # Take the arg-max per row instead of over the flattened array, then keep
    # the first row. The original flat argmax could exceed the class count for
    # batches larger than one.
    prediction_label = int(np.argmax(y, axis=-1).ravel()[0])
    if prediction_label not in labels:
        raise KeyError("predicted class %d has no entry in the label mapping"
                       % prediction_label)
    print("label:", prediction_label)
    print("this is a %s" % labels[prediction_label])
    return prediction_label
def main(_):
    """Run detection on FILEPATH using the model at MODELPATH.

    The single positional argument is ignored; presumably it is kept so the
    function matches an ``argv``-taking entry-point signature.
    """
    check(filename=FILEPATH, modelname=MODELPATH, size=SIZE)


if __name__ == '__main__':
    main(None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.