Created
May 13, 2016 07:01
-
-
Save haoliplus/59e036341344e7e7e4163573cbf1e087 to your computer and use it in GitHub Desktop.
Using Caffe model to predict
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python | |
################################################################################# | |
# File Name : test-one.py | |
# Created By : Hao Li | |
# Creation Date : [2016-04-08 13:44] | |
# Last Modified : [2016-04-26 22:24] | |
# Description : | |
################################################################################# | |
import os | |
import sys | |
import numpy | |
sys.path.insert(0, "path-to-caffe-python-lib") | |
import caffe | |
class CaffeModel(object):
    """Thin wrapper around ``caffe.Classifier`` for grayscale image prediction.

    The constructor loads the network definition, the trained weights and the
    mean file from a fixed directory layout, then builds a GPU classifier
    that is stored in ``self.net``.
    """

    def convert_mean_file(self, mean_filename):
        """Load a ``.binaryproto`` mean file and return it as a numpy array.

        Args:
            mean_filename: Path to the serialized ``BlobProto`` mean file.

        Returns:
            The first entry of the array decoded from the blob.
        """
        # Use a context manager so the file handle is closed deterministically.
        with open(mean_filename, "rb") as mean_file:
            proto_data = mean_file.read()
        blob = caffe.proto.caffe_pb2.BlobProto()
        blob.ParseFromString(proto_data)
        # blobproto_to_array yields a batch-like array; keep only entry 0.
        npy = numpy.array(caffe.io.blobproto_to_array(blob))[0]
        return npy

    def get_mean(self, mean_filename):
        """Compute the mean value per leading-axis entry of the mean blob.

        Args:
            mean_filename: Path to the ``.binaryproto`` mean file.

        Returns:
            A 1-D array obtained by averaging the decoded mean array over its
            two trailing axes (presumably height and width, leaving one value
            per channel -- confirm against the mean file's layout).
        """
        npy = self.convert_mean_file(mean_filename)
        return npy.mean(1).mean(1)

    def __init__(self, data_name="current"):
        """Initialise the classifier for the given data set.

        Args:
            data_name: Sub-directory name selecting which trained weights and
                which mean file to load.
        """
        caffe_root = "/mnt/disk0/lihao/plate-test/data/caffe-model"
        MODEL_FILE = '%s/deploy.prototxt' % (caffe_root)
        PRETRAINED = '%s/%s/caffe_alexnet_train_iter_2268000.caffemodel' % (
            caffe_root, data_name)
        mean_filename = '%s/data/%s/plate_mean.binaryproto' % (
            caffe_root, data_name)
        mean = self.get_mean(mean_filename)
        caffe.set_device(0)
        caffe.set_mode_gpu()
        # caffe.Classifier wraps input preprocessing and the Net forward pass.
        # It predicts in caffe.TEST mode. raw_scale must match the pixel
        # range that was used when the model was trained.
        self.net = caffe.Classifier(MODEL_FILE, PRETRAINED,
                                    mean=mean,
                                    raw_scale=255,
                                    image_dims=(64, 64))

    def getProb(self, img_path):
        """Return the classifier output for a single image.

        Args:
            img_path: Path to the image file; it is loaded as grayscale.

        Returns:
            The result of ``self.net.predict`` for a one-image batch.
        """
        img = caffe.io.load_image(img_path, color=False)
        pred = self.net.predict([img], oversample=False)
        return pred

    def predict(self, segments):
        """Return the classifier output for a batch of images.

        Args:
            segments: Iterable of image file paths; each is loaded as
                grayscale.

        Returns:
            A numpy array holding the predictions for all images.
        """
        # Fixed: the loader lives in ``caffe.io`` (the original ``caffe.ip``
        # does not exist and raised AttributeError).
        imgs = [caffe.io.load_image(path, color=False) for path in segments]
        # Oversampling would average over corner, center and mirrored crops;
        # it is disabled so each image is evaluated once.
        probs = self.net.predict(imgs, oversample=False)
        return numpy.array(probs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment