Created
March 3, 2018 00:34
-
-
Save EKami/087cfe603e0886ce4237cb3d55f2e241 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tempfile
import uuid
from io import BytesIO
from pathlib import Path

import Algorithmia
import requests
from Algorithmia.acl import ReadAcl
from PIL import Image
from torchlite.eval import eval
class AlgorithmError(Exception):
    """Error raised by the algorithm and reported back to the caller.

    The wrapped ``value`` is kept on the instance. Its string form is the
    ``repr`` of that value, with every escaped ``\\n`` sequence turned back
    into a real line break so multi-line messages display readably.
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        rendered = repr(self.value)
        # Joining the pieces split on the two-character sequence backslash-n
        # is equivalent to replacing that sequence with a newline.
        return "\n".join(rendered.split("\\n"))
def initialize_model():
    """Fetch the SRGAN generator weights from the Algorithmia data API.

    Relies on the module-level ``client``. Returns the object produced by
    ``DataFile.getFile()`` for the weights file, so the download cost is paid
    when this is called (at import time in this module) rather than per request.
    """
    weights_uri = "data://Ekami/torchlite/Generator.pth"
    weights_file = client.file(weights_uri)
    return weights_file.getFile()
# Inside an Algorithmia algorithm the platform authenticates the client, so no
# API key is passed in source. For local runs, Algorithmia.client() with no
# argument falls back to the ALGORITHMIA_API_KEY environment variable.
# NOTE(review): the key that used to be hard-coded here was published in source
# and should be revoked.
client = Algorithmia.client()
# Fetch the generator weights once at import time so the download happens on
# cold start instead of on every apply() call.
model = initialize_model()
def _parse_upscale_factor(raw):
    """Return ``raw`` as an int, defaulting to 4 when it is missing or falsy."""
    if not raw:
        return 4
    try:
        return int(raw)
    except (TypeError, ValueError) as err:
        raise AlgorithmError(
            "upscale_factor must be an integer, got %r" % (raw,)
        ) from err


def _upload_png(image, tmp_dir, file_name, dir_uri):
    """Save ``image`` as PNG in ``tmp_dir``, upload it to ``dir_uri``, return its data URI."""
    local_path = Path(tmp_dir) / file_name
    image.save(local_path, "png")
    remote_uri = dir_uri + "/" + file_name
    client.file(remote_uri).putFile(str(local_path))
    return remote_uri


def apply(input):
    """Super-resolve the image at ``image_url`` and store both images.

    Takes a JSON input in this form::

        {
            "image_url": "https://www.cnewyork.net/wp-content/uploads/2015/02/GeoffreyWojciechowski3.jpg",
            "upscale_factor": "4"
        }

    Args:
        input (dict): The parsed JSON. ``image_url`` is required;
            ``upscale_factor`` is optional and defaults to 4 when missing or empty.

    Returns:
        dict: ``{"sr_image": uri, "original_image": uri, "upscale_factor": int}``.
        Both URIs are ``data://`` URIs inside the public
        ``data://Ekami/srgan_results`` collection. Each request writes uniquely
        named files, so concurrent calls never overwrite each other.

    Raises:
        AlgorithmError: if the input has no ``image_url``, the upscale factor is
            not an integer, or the image cannot be downloaded or decoded.
    """
    if not isinstance(input, dict) or "image_url" not in input:
        raise AlgorithmError("Please provide a valid image input")

    # Results collection; create it publicly readable on first use.
    results_dir_uri = "data://Ekami/srgan_results"
    srgan_directory = client.dir(results_dir_uri)
    if not srgan_directory.exists():
        srgan_directory.create(acl=ReadAcl.public)

    image_url = input["image_url"]
    upscale_factor = _parse_upscale_factor(input.get("upscale_factor"))

    # Bound the download so a slow or unreachable host cannot hang the worker.
    try:
        image_response = requests.get(image_url, timeout=30)
        image_response.raise_for_status()
    except requests.RequestException as err:
        raise AlgorithmError(
            "Could not download image from {}: {}".format(image_url, err)
        ) from err
    try:
        image = Image.open(BytesIO(image_response.content))
    except OSError as err:
        raise AlgorithmError(
            "Content at {} is not a readable image".format(image_url)
        ) from err

    # A per-request id keeps uploaded file names unique.
    request_id = uuid.uuid4().hex
    # TemporaryDirectory guarantees local files are cleaned up even on error.
    with tempfile.TemporaryDirectory() as tmp_dir:
        original_uri = _upload_png(
            image, tmp_dir, "original_{}.png".format(request_id), results_dir_uri
        )
        # NOTE(review): model.buffer is presumably the binary stream behind the
        # file object returned by getFile(); srgan_eval returns a sequence of
        # images, one per input. use_cuda=True presumably needs a GPU runtime.
        sr_img = eval.srgan_eval([image], model.buffer, upscale_factor, use_cuda=True)[0]
        sr_uri = _upload_png(
            sr_img, tmp_dir, "sr_{}.png".format(request_id), results_dir_uri
        )

    return {
        "sr_image": sr_uri,
        "original_image": original_uri,
        "upscale_factor": upscale_factor,
    }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment