Last active
November 26, 2019 21:02
-
-
Save zeryx/6be9395c8493472bae9717592c109697 to your computer and use it in GitHub Desktop.
A non-trivial algorithm that recursively calls itself when it receives a large batch of work requests.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from multiprocessing import Manager, Pool | |
import Algorithmia | |
import os | |
import re | |
# Shared setup for a TensorFlow image classification algorithm.
# The example shows how to solve batch processing problems with algorithm
# recursion and pipelining.

# Algorithmia data URIs and algorithm identifiers used by the functions below.
TEMP_COLLECTION = "data://.session/"
SIMD_ALGO = "util/SmartImageDownloader/0.2.14"
MODEL_FILE = "data://zeryx/InceptionNetDemo/classify_image_graph_def.pb"
CONVERSION_FILE = "data://zeryx/InceptionNetDemo/imagenet_synset_to_human_label_map.txt"
LABEL_FILE = "data://zeryx/InceptionNetDemo/imagenet_2012_challenge_label_map_proto.pbtxt"

# IMPORTANT: be aware of the algorithm version you are calling. Because this
# reference points back at this very algorithm, during development you may need
# to replace it with a version hash.
# TODO: We'll improve this experience in the future
THIS_ALGO = "zeryx/recursive_image_example/0.1.x"

# Upper bound on the number of recursive requests open at any one time; this
# keeps us from overwhelming the development environment.
NUM_PARALLEL_REQUESTS = 10

# Client shared by every function in this module.
client = Algorithmia.client()
class AlgorithmError(Exception):
    """Exception carrying an arbitrary payload in ``value``.

    Converting the exception to a string yields the ``repr()`` of the payload.
    """

    def __init__(self, value):
        # Keep the payload accessible to handlers as ``err.value``.
        self.value = value

    def __str__(self):
        return "{!r}".format(self.value)
def load_model():
    """Fetch the model artifacts and build the graph plus the label lookup.

    Returns:
        A ``(graph, label_index)`` tuple, where ``graph`` is a ``tf.Graph``
        with the serialized graph definition imported into it and
        ``label_index`` is the dict produced by ``load_label_index``.
    """
    # Each getFile() call hands back a local file object; only its path is used.
    labels_path = client.file(LABEL_FILE).getFile().name
    model_path = client.file(MODEL_FILE).getFile().name
    conversion_path = client.file(CONVERSION_FILE).getFile().name

    inception_graph = tf.Graph()
    with inception_graph.as_default():
        frozen_def = tf.GraphDef()
        with tf.gfile.GFile(model_path, 'rb') as model_fh:
            frozen_def.ParseFromString(model_fh.read())
        # name='' imports the nodes without a name prefix.
        tf.import_graph_def(frozen_def, name='')

    return inception_graph, load_label_index(conversion_path, labels_path)
def load_label_index(conversion_path, label_path):
    """Build the mapping from integer node id to human-readable label.

    Args:
        conversion_path: Path to a local text file mapping a uid to a human
            label, one entry per line (presumably ``uid<TAB>label``; the
            format is inferred from the parsing regex).
        label_path: Path to a local text file containing ``target_class: N``
            and ``target_class_string: "uid"`` lines, with any indentation.

    Returns:
        A dict mapping each integer ``target_class`` to its human label.

    Raises:
        IndexError: If a non-blank conversion line does not split into the
            expected pieces.
        KeyError: If a uid from the label file is missing from the
            conversion file (after the failure has been logged).
    """
    uid_to_human = {}
    p = re.compile(r'[n\d]*[ \S,]*')
    with open(conversion_path) as f:
        for raw_line in f:
            line = raw_line.rstrip('\n')
            # Skip blank lines instead of failing on them; a missing final
            # newline no longer drops the last entry.
            if not line.strip():
                continue
            parsed_items = p.findall(line)
            # findall yields the uid first and the label third; the item in
            # between is the empty match at the separator.
            uid_to_human[parsed_items[0]] = parsed_items[2]

    node_id_to_uid = {}
    # Both inputs are read with the builtin open(), inside a context manager
    # so the handle is always closed.
    with open(label_path) as f:
        for raw_line in f:
            # Strip first so matching does not depend on exact indentation
            # or on the presence of a trailing newline.
            line = raw_line.strip()
            if line.startswith('target_class:'):
                target_class = int(line.split(': ')[1])
            elif line.startswith('target_class_string:'):
                # The uid is wrapped in double quotes in the file.
                node_id_to_uid[target_class] = line.split(': ')[1].strip('"')

    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
        if val not in uid_to_human:
            # If the logging call returns, the lookup below raises KeyError.
            tf.logging.fatal('Failed to locate: %s', val)
        node_id_to_name[key] = uid_to_human[val]
    return node_id_to_name
def id_to_string(index_file, node_id):
    """Return the label stored for ``node_id``, or ``''`` when it is absent.

    Args:
        index_file: Container of labels keyed by node id.
        node_id: Key to look up.
    """
    return index_file[node_id] if node_id in index_file else ''
def get_image(url):
    """Fetch an image through the downloader algorithm and return a local path.

    The URL is piped to ``SIMD_ALGO``; the first entry of the returned
    ``savePath`` is then fetched as a local file, which is renamed so that it
    ends with the same extension as that saved path.

    Args:
        url: Image location; it is converted with ``str()`` before sending.

    Returns:
        The path of the renamed local file.
    """
    response = client.algo(SIMD_ALGO).pipe({'image': str(url)})
    remote_path = response.result['savePath'][0]
    downloaded = client.file(remote_path).getFile().name
    # Reuse whatever follows the last '.' of the remote path as the extension.
    local_path = downloaded + '.' + remote_path.split('.')[-1]
    os.rename(downloaded, local_path)
    return local_path
def inference(image):
    """Classify one image file and return its five highest-scoring labels.

    Args:
        image: Path of the image file. Its raw bytes are fed to the graph's
            ``DecodeJpeg/contents:0`` input (presumably JPEG data; confirm
            what the downloader produces).

    Returns:
        ``{'tags': [...]}`` where each tag is a dict with ``'class'`` (the
        label text, ``''`` if the node id is unknown) and ``'confidence'``
        (a plain Python number), ordered from highest to lowest score.
    """
    # Read through a context manager so the file handle is closed promptly.
    with tf.gfile.FastGFile(image, 'rb') as image_file:
        image_data = image_file.read()

    # Uses the module-level ``graph``; a new session is opened per call.
    with tf.Session(graph=graph) as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)

    # argsort is ascending: take the last five, then reverse for best-first.
    top_k = predictions.argsort()[-5:][::-1]
    tags = [
        {
            'class': id_to_string(label_index, node_id),
            # .item() converts the numpy scalar to a native Python value.
            'confidence': predictions[node_id].item(),
        }
        for node_id in top_k
    ]
    return {'tags': tags}
def algorithm_recursion_(input, errorQ):
    """Fan the work out to ``_algo`` in a process pool and return the pending result.

    The input is split into chunks of 5 and each chunk is paired with the
    shared ``errorQ`` so every worker can report and observe failures. The
    pool size is capped at ``NUM_PARALLEL_REQUESTS`` so we don't overload
    anything.

    Args:
        input: Sliceable sequence of work items.
        errorQ: Queue-like object shared with the workers; it must be usable
            from other processes (e.g. a ``Manager().Queue()``).

    Returns:
        The ``AsyncResult`` from ``Pool.starmap_async``; its ``.get()`` blocks
        until all chunks finish and returns one ``_algo`` return value per
        chunk, in chunk order.
    """
    pool = Pool(NUM_PARALLEL_REQUESTS)
    process_data = [(chunk, errorQ) for chunk in _chunks(input, 5)]
    async_ops = pool.starmap_async(_algo, process_data)
    # No further tasks will be submitted: close the pool so its workers exit
    # once the queued chunks are done instead of lingering. Already submitted
    # work still completes and stays retrievable through ``async_ops``.
    pool.close()
    return async_ops
def _chunks(l, n): | |
"""Yield successive n-sized chunks from l.""" | |
for i in range(0, len(l), n): | |
yield l[i:i + n] | |
def _algo(algo_data, errorQ): | |
"""The primary working algorithm for our parallel threads, makes parallel requests and checks if errors exist""" | |
try: | |
if errorQ.empty(): | |
print("processing chunk..") | |
response = client.algo(THIS_ALGO).pipe(algo_data).result | |
print("finished chunk..") | |
return response | |
else: | |
return None | |
except Exception as e: | |
errorQ.put(e) | |
def batch_apply(input):
    """Run ``apply`` on each item in order and return the list of results.

    Processing is sequential; it can be made parallel if necessary.
    """
    return [apply(image) for image in input]
def apply(input):
    """Classify one image, or a batch of images, recursing for large batches.

    Args:
        input: One of
            * a ``str`` image URL,
            * a ``dict`` containing an ``"image"`` URL,
            * a ``list`` of such inputs.

    Returns:
        For a single image, ``{"image": url, "results": inference output}``.
        For a list of at most 5 items, the list of per-item results.
        For a longer list, the results of the first 5 items followed by what
        the recursive calls returned for the remaining items.

    Raises:
        Exception: ``"Input format invalid"`` for any other input type; or the
            first exception reported by a recursive worker.
    """
    # Single image, given either as a bare URL or wrapped in a dict.
    if isinstance(input, str):
        return {"image": input, "results": inference(get_image(input))}
    if isinstance(input, dict) and "image" in input:
        url = input['image']
        return {"image": url, "results": inference(get_image(url))}
    if not isinstance(input, list):
        raise Exception("Input format invalid")

    # Share of the list handled by this worker; matches the chunk size used
    # when the remainder is split up for the recursive calls.
    local_batch_size = 5

    # A small list is cheaper to process here than to ship to other machines.
    # Using <= keeps a list of exactly local_batch_size items on this path, so
    # no manager or pool is spawned when there is no remote work.
    if len(input) <= local_batch_size:
        return batch_apply(input)

    # Keep some work for this worker and pass the remainder to recursive calls.
    input_for_this_worker = input[:local_batch_size]
    remaining_work = input[local_batch_size:]

    # The managed queue carries exceptions back from the worker processes.
    # The context manager shuts the manager's server process down on exit.
    with Manager() as manager:
        errorQ = manager.Queue()
        # Start the remote calls first so they overlap with the local batch.
        eventual_remote_results = algorithm_recursion_(remaining_work, errorQ)
        local_results = batch_apply(input_for_this_worker)
        concurrent_results = eventual_remote_results.get()
        # Check the error queue before returning: if a worker failed,
        # re-raise its exception instead of returning partial results.
        if not errorQ.empty():
            raise errorQ.get()

    # NOTE(review): concurrent_results holds one entry per remote chunk
    # (whatever the recursive call returned for that chunk), so the tail of
    # the combined list appears to be per-chunk rather than per-image.
    # Confirm this is the intended output shape before flattening.
    return local_results + concurrent_results
# Build the graph and label lookup once, when the module is loaded;
# inference() reads both of these module-level names.
graph, label_index = load_model()

if __name__ == "__main__":
    # Demo run: 22 copies of one image URL, enough to take the recursive
    # path in apply().
    sample_url = "https://i.imgur.com/AJaSkUL.jpg"
    input = [sample_url] * 22
    final = apply(input)
    print(final)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment