Skip to content

Instantly share code, notes, and snippets.

@EKami
Created March 3, 2018 00:34
Show Gist options
  • Save EKami/087cfe603e0886ce4237cb3d55f2e241 to your computer and use it in GitHub Desktop.
Save EKami/087cfe603e0886ce4237cb3d55f2e241 to your computer and use it in GitHub Desktop.
import os
import tempfile
import uuid
from io import BytesIO
from pathlib import Path

import Algorithmia
import requests
from Algorithmia.acl import ReadAcl
from PIL import Image
from torchlite.eval import eval
class AlgorithmError(Exception):
    """Error raised for invalid requests to this algorithm.

    The payload is kept on ``value``. When rendered as a string it is shown via
    ``repr`` with escaped ``\\n`` sequences turned back into real line breaks,
    so multi-line messages stay readable.
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        rendered = repr(self.value)
        return rendered.replace("\\n", "\n")
def initialize_model():
    """Fetch ``Generator.pth`` from the ``Ekami/torchlite`` data collection.

    Meant to run once at module import, so the download happens during cold
    start rather than inside each request.

    Returns:
        Whatever ``client.file(...).getFile()`` returns for the stored file.
    """
    weights_uri = "data://Ekami/torchlite/Generator.pth"
    weights_file = client.file(weights_uri).getFile()
    return weights_file
# Inside an Algorithmia algorithm the client is authenticated by the platform,
# so no API key is passed here. Never hard-code credentials in source.
# NOTE(review): a literal API key was previously committed on this line and
# published; it should be revoked. For local runs the client is expected to
# read ALGORITHMIA_API_KEY from the environment — confirm for your client version.
client = Algorithmia.client()
# Loaded once at import (cold start) and reused by every apply() call.
model = initialize_model()
# Data collection that receives both the original and the upscaled images.
_RESULTS_DIR_URI = "data://Ekami/srgan_results"
# Used when the request omits upscale_factor or passes a falsy value.
_DEFAULT_UPSCALE_FACTOR = 4
# Seconds to wait on the image host; requests waits forever by default.
_DOWNLOAD_TIMEOUT_S = 30


def _fetch_image(image_url):
    """Download ``image_url`` and open it as a PIL image.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so an
            error page is never handed to Pillow as if it were an image.
        requests.RequestException: on connection failures or timeouts.
        OSError: if the downloaded bytes cannot be decoded as an image.
    """
    response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_S)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def _upload_png(image, file_name, workdir):
    """Save ``image`` as PNG in ``workdir`` and upload it to the results collection.

    Returns:
        str: the ``data://`` URI the file was uploaded to.
    """
    local_path = Path(workdir) / file_name
    image.save(str(local_path), "png")
    data_uri = _RESULTS_DIR_URI + "/" + file_name
    client.file(data_uri).putFile(str(local_path))
    return data_uri


def apply(input):
    """Upscale the image at ``input["image_url"]`` and store the results.

    Takes a JSON input in this form::

        {
            "image_url": "https://www.cnewyork.net/wp-content/uploads/2015/02/GeoffreyWojciechowski3.jpg",
            "upscale_factor": "4"
        }

    ``upscale_factor`` is optional; a missing, null or empty value means 4.

    Args:
        input (dict): The parsed JSON request.

    Returns:
        dict: ``{"sr_image": uri, "original_image": uri, "upscale_factor": int}``
        where both URIs are ``data://`` locations inside the results collection.

    Raises:
        AlgorithmError: if ``input`` is not an object with a non-empty ``image_url``.
        ValueError / TypeError: if ``upscale_factor`` cannot be converted with ``int()``.
    """
    # isinstance first: for a plain-string input, `"image_url" in input` would be
    # a substring test rather than a key lookup.
    if not isinstance(input, dict) or not input.get("image_url"):
        raise AlgorithmError("Please provide a valid image input")

    image_url = input["image_url"]
    upscale_factor = int(input.get("upscale_factor") or _DEFAULT_UPSCALE_FACTOR)

    # Create the results collection on first use, with a public read ACL.
    results_dir = client.dir(_RESULTS_DIR_URI)
    if not results_dir.exists():
        results_dir.create(acl=ReadAcl.public)

    image = _fetch_image(image_url)

    # A fresh id per request keeps concurrent calls from overwriting each other's
    # uploads, which a fixed or URL-derived file name would do.
    request_id = uuid.uuid4().hex

    # TemporaryDirectory removes the local PNGs even if inference or an upload fails.
    with tempfile.TemporaryDirectory() as workdir:
        original_uri = _upload_png(image, "original_" + request_id + ".png", workdir)
        # NOTE(review): `model` is the handle returned by getFile() at import time and
        # is reused across requests; if srgan_eval reads model.buffer to EOF, later
        # calls may need model.buffer.seek(0) first — confirm against torchlite.
        sr_image = eval.srgan_eval([image], model.buffer, upscale_factor, use_cuda=True)[0]
        sr_uri = _upload_png(sr_image, "sr_" + request_id + ".png", workdir)

    return {
        "sr_image": sr_uri,
        "original_image": original_uri,
        "upscale_factor": upscale_factor,
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment