Performance improvements with TensorFlow & Algorithmia
# This code example describes how to pre-load a TensorFlow graph file
# into an Algorithmia container and load the graph into memory.
# This approach allows us to preserve the graph in system memory between API calls,
# improving overall performance.
# We also document how to evict the TensorFlow GPU memory context between API requests
# by moving the session into a separate process, and how to define the amount of GPU
# memory an algorithm uses. These GPU tweaks significantly improve performance on
# Algorithmia's infrastructure. (Written against the TensorFlow 1.x API.)
import Algorithmia
import multiprocessing
import os
import tarfile

import numpy as np
import tensorflow as tf
from PIL import Image
from object_detection.utils import label_map_util

client = Algorithmia.client()

## We define the graph and category index in advance, and provide them with default values.
GRAPH, CATEGORY_INDEX, MODEL_NAME = (tf.Graph(), {}, 'default')
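
# The following names are referenced later in the gist but never defined there;
# the values here are illustrative placeholders -- adjust them for your own algorithm.
SIMD_ALGO = 'util/SmartImageDownloader/0.2.x'  # image downloader algorithm used by get_image()
NUM_CLASSES = 90           # e.g. 90 for COCO-trained detection models
GPU_MEMORY_FRACTION = 0.6  # fraction of GPU memory this algorithm may claim
MAX_BOXES = 20             # maximum number of detection boxes to return
MIN_SCORE = 0.5            # minimum detection confidence to keep


# The original raises AlgorithmError without defining it; a minimal stand-in:
class AlgorithmError(Exception):
    """User-facing algorithm error, following Algorithmia's AlgoError message convention."""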
def get_image(url):
    # Resolve the image URL with the downloader algorithm (SIMD_ALGO), then give
    # the local temp file its extension back so PIL can detect the format.
    output_url = client.algo(SIMD_ALGO).pipe({'image': str(url)}).result['savePath'][0]
    temp_file = client.file(output_url).getFile().name
    os.rename(temp_file, temp_file + '.' + output_url.split('.')[-1])
    return temp_file + '.' + output_url.split('.')[-1]
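
# load_image_into_numpy_array() is referenced below but not defined in the gist;
# this is the standard helper from the TensorFlow object detection tutorial.
def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)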
# This function assumes that your model is in tar.gz format; if not, extract your file accordingly.
def download_model(model_name):
    global GRAPH
    global CATEGORY_INDEX
    global MODEL_NAME
    download_base = 'data://some/base/'
    model_file = model_name + '.tar.gz'
    # Path to the frozen detection graph. This is the actual model that is used for the object detection.
    path_to_graph = model_name + '/frozen_inference_graph.pb'
    # List of the strings that are used to add the correct label to each box.
    path_to_labels = client.file("data://path/to/labels.pbtxt").getFile().name
    if model_name != MODEL_NAME:
        print('model name %s != %s, reloading...' % (model_name, MODEL_NAME))
        # Only download and extract the archive if the graph isn't already cached locally.
        if not os.path.isfile(path_to_graph):
            try:
                local_file = client.file(download_base + model_file).getFile().name
            except Exception:
                raise AlgorithmError("AlgoError3000: invalid model name.")
            tar_file = tarfile.open(local_file)
            for file in tar_file.getmembers():
                file_name = os.path.basename(file.name)
                if 'frozen_inference_graph.pb' in file_name:
                    tar_file.extract(file, os.getcwd())
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(path_to_graph, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        label_map = label_map_util.load_labelmap(path_to_labels)
        categories = label_map_util.convert_label_map_to_categories(
            label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        ## Set the global variables so they can be used later without repeating this process.
        GRAPH = detection_graph
        CATEGORY_INDEX = category_index
        MODEL_NAME = model_name
## We don't activate allow_growth, as having well-defined GPU memory use profiles helps in scheduling.
## It also offers a slight performance improvement, which is always nice :)
def generate_gpu_config(memory_fraction):
    config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
    return config
# Since this function is never used outside of a multiprocessing Process, it returns nothing.
# However, it does mutate the `result` object defined in apply().
def execute_tensorflow(graph, image, category_index, max_boxes, min_score, result):
    with graph.as_default():
        # Creating the session inside the child process means the GPU memory
        # context dies with the process once the work is done.
        with tf.Session(graph=graph, config=generate_gpu_config(GPU_MEMORY_FRACTION)) as sess:
            image = Image.open(image).convert('RGB')
            image_np = load_image_into_numpy_array(image)
            height, width, _ = image_np.shape
            image_np_expanded = np.expand_dims(image_np, axis=0)
            image_tensor = graph.get_tensor_by_name('image_tensor:0')
            boxes = graph.get_tensor_by_name('detection_boxes:0')
            scores = graph.get_tensor_by_name('detection_scores:0')
            classes = graph.get_tensor_by_name('detection_classes:0')
            num_detections = graph.get_tensor_by_name('num_detections:0')
            (boxes, scores, classes, num_detections) = sess.run(
                [boxes, scores, classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            boxes = np.squeeze(boxes)
            classes = np.squeeze(classes).astype(np.int32)
            scores = np.squeeze(scores)
            prepare_output(height, width, boxes, classes, scores, category_index,
                           max_boxes, min_score, result)
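
# prepare_output() is called above but not defined in the original gist. A minimal
# sketch of what it likely does: filter detections by score, cap the box count,
# map class ids to human-readable labels, and append dicts to the shared list.
def prepare_output(height, width, boxes, classes, scores, category_index,
                   max_boxes, min_score, result):
    for i in range(min(max_boxes, boxes.shape[0])):
        if scores[i] < min_score:
            continue
        ymin, xmin, ymax, xmax = boxes[i].tolist()
        label = category_index.get(classes[i], {}).get('name', str(classes[i]))
        result.append({
            'label': label,
            'score': float(scores[i]),
            # Convert normalized coordinates to pixel coordinates.
            'box': {'x1': xmin * width, 'y1': ymin * height,
                    'x2': xmax * width, 'y2': ymax * height}
        })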
def apply(input):
    model_name = "some_default_model_file"
    output = None
    if isinstance(input, str):
        image = get_image(input)
    elif isinstance(input, dict):
        # - process your input fields here -
        # (the 'image' and 'output' field names are assumptions; adjust to your schema)
        image = get_image(input['image'])
        output = input.get('output')
        if 'model' in input:
            model_name = input['model']
    else:
        raise AlgorithmError("AlgoError3000: input must be a URL string or a dict.")
    download_model(model_name)
    # Don't forget to put all of the completed work into a multiprocessing-friendly
    # structure, like this list format.
    result = multiprocessing.Manager().list()
    # execute_tensorflow is run in a separate process so that when the job is
    # complete we can kill the GPU context.
    p = multiprocessing.Process(target=execute_tensorflow,
                                args=(GRAPH, image, CATEGORY_INDEX,
                                      MAX_BOXES, MIN_SCORE, result))
    p.start()
    p.join()
    boxes = [x for x in result]
    if output:
        im = Image.open(image).convert('RGB')
        # transform_image() is not defined in this gist; it's assumed to draw the
        # detected boxes onto the image and save it to the requested output location.
        image = transform_image(im, output, boxes)
        return {'data': boxes, 'image': image}
    else:
        return {'data': boxes}
# Pre-load the default model when the container starts, so the graph is already
# in memory before the first API call arrives.
download_model("some_default_model_file")
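
# --- Local smoke test (not part of the algorithm; on Algorithmia the platform
# --- invokes apply() directly). The URL and model name here are hypothetical.
if __name__ == '__main__':
    print(apply({'model': 'some_default_model_file',
                 'image': 'http://example.com/cat.jpg'}))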