Skip to content

Instantly share code, notes, and snippets.

@simon-mo
Last active March 14, 2019 23:56
Show Gist options
  • Select an option

  • Save simon-mo/f8abc1711eb4102ec2b9331d14d1b333 to your computer and use it in GitHub Desktop.

Select an option

Save simon-mo/f8abc1711eb4102ec2b9331d14d1b333 to your computer and use it in GitHub Desktop.
import base64
import logging
from io import BytesIO

import numpy as np
import uvicorn
from PIL import Image
from starlette.applications import Starlette
from starlette.responses import JSONResponse
def arr_to_png_str(arr):
    """Encode an image array as a base64 PNG string.

    ``arr`` is passed straight to ``PIL.Image.fromarray``, so its dtype and
    shape must be something Pillow can map to an image mode (for example
    uint8 of shape (H, W) or (H, W, 3)). The image keeps whatever mode
    Pillow infers; no conversion is applied.

    Returns:
        The PNG bytes, base64-encoded and decoded to ``str``.
    """
    png_buffer = BytesIO()
    Image.fromarray(arr).save(png_buffer, "PNG")
    return base64.b64encode(png_buffer.getvalue()).decode()
logger = logging.getLogger(__name__)

app = Starlette()


def _bad_request(message):
    """Return a 400 JSON response carrying ``message`` under "error"."""
    return JSONResponse({"error": message}, status_code=400)


# Clients send the image in the request body. GET is kept for existing
# callers; POST is also accepted because a body on GET has no defined
# semantics in HTTP.
@app.route('/api/query/{model_id}/{model_version}', methods=["GET", "POST"])
async def homepage(request):
    """Halve every pixel of a base64-encoded image and return it as PNG.

    Expects a JSON object body with a base64-encoded image under ``"data"``.
    The decoded image is turned into a NumPy array, divided by two and cast
    to uint8. The result is returned base64-PNG-encoded, together with the
    ``model_id`` and ``model_version`` path parameters (echoed as strings).

    Malformed requests get a 400 JSON response instead of an unhandled
    server error. These are bodies that are not valid JSON, lack a string
    ``"data"`` field, or whose data does not decode to an image.
    """
    model_id = request.path_params["model_id"]
    model_version = request.path_params["model_version"]

    try:
        payload = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        return _bad_request("request body must be valid JSON")
    encoded = payload.get("data") if isinstance(payload, dict) else None
    if not isinstance(encoded, str):
        return _bad_request('request body must contain a base64 string under "data"')

    try:
        image = Image.open(BytesIO(base64.b64decode(encoded.encode())))
        # np.array forces Pillow to actually load (and validate) the pixels.
        arr = np.array(image)
    except (ValueError, OSError):
        # binascii.Error (bad base64 padding) subclasses ValueError; Pillow
        # raises OSError subclasses for unidentified or truncated image data.
        return _bad_request('"data" must be a base64-encoded image')

    processed = (arr / 2).astype("uint8")
    logger.debug("processed image shape: %s", processed.shape)
    return JSONResponse({
        "model": model_id,
        "version": model_version,
        "type": "image/png",
        "data": arr_to_png_str(processed),
    })
if __name__ == '__main__':
    # Serve the app directly on port 8000, listening on all interfaces
    # (0.0.0.0) rather than loopback only.
    uvicorn.run(app, port=8000, host='0.0.0.0')
FROM python:3.6
# Quote the extras spec so the shell does not treat [full] as a glob pattern.
RUN pip install --no-cache-dir pillow "starlette[full]" uvicorn numpy
COPY app.py test.py /
EXPOSE 8000
# uvicorn binds to 127.0.0.1 by default, which is unreachable from outside
# the container; listen on all interfaces on the port test.py targets.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
import requests
import numpy as np
from PIL import Image
from io import BytesIO
import base64
def arr_to_png_str(arr):
    """Return ``arr`` as a base64-encoded PNG string written in RGB mode.

    The array is wrapped with ``PIL.Image.fromarray`` and converted to RGB
    before encoding, so the PNG is always saved as an RGB image.
    """
    rgb_image = Image.fromarray(arr).convert("RGB")
    with BytesIO() as png_buffer:
        rgb_image.save(png_buffer, "PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
# Random RGB image over the full uint8 range: randint's upper bound is
# exclusive, so 256 is required for the value 255 to ever appear.
request_arr = np.random.randint(0, 256, size=(444, 444, 3), dtype=np.uint8)
# Expected server output: every pixel value halved and truncated to uint8.
response_arr = (request_arr / 2).astype("uint8")
MODEL_ID = 111
MODEL_VERSION = 2.1
# Seconds to wait for the server before failing instead of hanging forever.
REQUEST_TIMEOUT_S = 30
request_json_data = {
    "type": "image/png",
    "data": arr_to_png_str(request_arr),
}
expected_json_data = {
    "model": MODEL_ID,
    "version": MODEL_VERSION,
    "type": "image/png",
    "data": arr_to_png_str(response_arr),
}
resp = requests.get(
    f"http://localhost:8000/api/query/{MODEL_ID}/{MODEL_VERSION}",
    json=request_json_data,
    timeout=REQUEST_TIMEOUT_S,
)
# Surface HTTP errors directly rather than as a JSON decode failure below.
resp.raise_for_status()
body = resp.json()
# Path parameters are echoed back as strings, so compare their text form.
assert body["model"] == str(expected_json_data["model"]), body["model"]
assert body["version"] == str(expected_json_data["version"]), body["version"]
assert body["type"] == expected_json_data["type"], body["type"]
assert body["data"] == expected_json_data["data"], "processed image does not match"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment