Last active
March 1, 2018 16:44
-
-
Save EKami/fc078fc0d3274d94a3c03364e4b68bd5 to your computer and use it in GitHub Desktop.
Code 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tempfile
import uuid
from io import BytesIO
from pathlib import Path
from urllib.parse import urlparse

import Algorithmia
import requests
from Algorithmia.acl import ReadAcl
from PIL import Image

from torchlite.eval import eval
# Algorithmia client shared by this module.
# No API key is passed: code running inside an Algorithmia algorithm is
# authenticated by the platform, and when no key is given the client library
# falls back to the ALGORITHMIA_API_KEY environment variable (e.g. for local runs).
# SECURITY: the key that was previously hard-coded here was published with this
# file and should be revoked.
client = Algorithmia.client()
class AlgorithmError(Exception):
    """Error reported back to the caller of the algorithm.

    ``str()`` of the error is ``repr(value)`` with escaped ``\\n`` sequences
    turned back into real line breaks, so multi-line messages stay readable.
    """

    def __init__(self, value):
        # Forward ``value`` to Exception so ``args`` is populated. Without this,
        # pickling/copying the error fails because it is rebuilt as cls(*args)
        # with an empty ``args``.
        super().__init__(value)
        self.value = value

    def __str__(self):
        return repr(self.value).replace("\\n", "\n")
# Data collection under which every request gets its own result directory.
_RESULTS_ROOT = "data://Ekami/srgan_results"


def _parse_upscale_factor(value, default=4):
    """Return ``value`` as a positive int, or ``default`` when it is missing/empty.

    Raises:
        AlgorithmError: if ``value`` is not convertible to a positive integer.
    """
    if not value:
        return default
    try:
        factor = int(value)
    except (TypeError, ValueError) as err:
        raise AlgorithmError(
            "upscale_factor must be an integer, got {!r}".format(value)
        ) from err
    if factor < 1:
        raise AlgorithmError(
            "upscale_factor must be a positive integer, got {!r}".format(value)
        )
    return factor


def _download_image(image_url, timeout=30):
    """Fetch ``image_url`` and decode it with PIL as an RGB image.

    Raises:
        AlgorithmError: if the download fails or the bytes are not an image.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as err:
        raise AlgorithmError("Could not download image: {}".format(err)) from err
    try:
        image = Image.open(BytesIO(response.content))
        image.load()  # force decoding now so corrupt data fails here
    except OSError as err:
        raise AlgorithmError(
            "Could not decode image at {}".format(image_url)
        ) from err
    # PNG cannot store some modes (e.g. CMYK JPEGs). NOTE(review): RGB is also
    # assumed to be what the SRGAN generator expects - confirm against torchlite.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def _image_stem(image_url):
    """Return the URL's file name without query string or extension ("image" if none)."""
    stem = Path(urlparse(image_url).path).stem
    return stem or "image"


def _ensure_results_root():
    """Create the results collection (readable by my algorithms) if it is missing."""
    results_root = client.dir(_RESULTS_ROOT)
    if not results_root.exists():
        results_root.create(acl=ReadAcl.my_algos)


def _upload_png(img, dir_uri, filename):
    """Save ``img`` as PNG in a temp dir, upload it to ``dir_uri/filename``, return that URI."""
    uri = dir_uri + "/" + filename
    # A TemporaryDirectory is always cleaned up, even if save/upload raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = str(Path(tmp_dir) / filename)
        img.save(local_path, "png")
        client.file(uri).putFile(local_path)
    return uri


def apply(input):
    """
    Super-resolve the image at ``input["image_url"]`` with the SRGAN generator.

    Takes a json input in this form:
    {
        "image_url": "https://www.cnewyork.net/wp-content/uploads/2015/02/GeoffreyWojciechowski3.jpg",
        "upscale_factor": "4"
    }
    ``upscale_factor`` is optional and defaults to 4.

    The original image and the super-resolved image are both written as PNG
    into a fresh, publicly readable sub-directory of ``data://Ekami/srgan_results``.

    Args:
        input (dict): The parsed json
    Returns:
        dict: A dict in the form:
            {"sr_image": uri, "original_image": uri, "upscale_factor": upscale_factor}
        where both URIs are ``data://`` URIs of the uploaded PNG files.
    Raises:
        AlgorithmError: if ``image_url`` is missing, ``upscale_factor`` is not a
            positive integer, or the image cannot be downloaded or decoded.
    """
    if not isinstance(input, dict) or not input.get("image_url"):
        # Raise helpful error message
        raise AlgorithmError("Please provide a valid image input")

    image_url = input["image_url"]
    if not isinstance(image_url, str):
        raise AlgorithmError("Please provide a valid image input")
    upscale_factor = _parse_upscale_factor(input.get("upscale_factor"))
    image = _download_image(image_url)

    _ensure_results_root()
    # One public sub-directory per request so concurrent calls never collide.
    result_dir_uri = _RESULTS_ROOT + "/" + uuid.uuid4().hex
    client.dir(result_dir_uri).create(acl=ReadAcl.public)

    # ``DataFile.path`` is only the remote data path; getFile() downloads the
    # weights to a local temporary file the model loader can open.
    generator_model = client.file("data://Ekami/torchlite/Generator.pth").getFile().name

    stem = _image_stem(image_url)
    original_uri = _upload_png(image, result_dir_uri, "original_" + stem + ".png")

    # NOTE(review): use_cuda=True is kept from the original and presumably
    # requires a GPU-enabled algorithm environment.
    sr_img = eval.srgan_eval([image], generator_model, upscale_factor, use_cuda=True)[0]
    sr_uri = _upload_png(sr_img, result_dir_uri, "sr_" + stem + ".png")

    return {
        "sr_image": sr_uri,
        "original_image": original_uri,
        "upscale_factor": upscale_factor,
    }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment