Last active
September 25, 2023 02:03
-
-
Save cgcardona/2709914b11a7376fcc134000c0e92505 to your computer and use it in GitHub Desktop.
Flask app that serves GET requests at "/" and accepts a query-string parameter of the form `?prompt=my creative and expressive stable diffusion prompt`.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Flask service that turns a text prompt into a Stable Diffusion image."""

# Standard library.
import io
import math
import os
import time

# Third-party.
import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, request, send_file
from torch import autocast

# Create the Flask application.
app = Flask(__name__)

# Fail fast if no CUDA device is available.  A bare `assert` would be
# silently stripped under `python -O`, so raise an explicit error instead.
if not torch.cuda.is_available():
    raise RuntimeError("A CUDA-capable GPU is required but none was found.")

# Model checkpoints:
#   Stable Diffusion v1.4: CompVis/stable-diffusion-v1-4
#   Stable Diffusion v1.5: runwayml/stable-diffusion-v1-5
# NOTE(review): `use_auth_token` is deprecated in newer diffusers releases
# in favor of `token` — confirm against the pinned diffusers version.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", use_auth_token=True
).to("cuda")
def run_inference(prompt, parent_dir="/path/to/generated-assets"):
    """Generate an image for *prompt*, archive it to disk, and return it.

    Args:
        prompt: Text prompt passed to the Stable Diffusion pipeline.
        parent_dir: Directory under which a per-image subdirectory is
            created.  Previously hard-coded; the default preserves the
            old behavior.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 holding the PNG-encoded
        image, suitable for ``flask.send_file``.
    """
    # Build a filesystem-friendly slug from the first 20 characters of the
    # prompt: drop commas, then turn spaces into hyphens.
    slug = prompt[:20].replace(",", "").replace(" ", "-")

    # Run the pipeline under autocast for mixed-precision CUDA inference.
    with autocast("cuda"):
        image = pipe(prompt).images[0]

    # Encode the image to PNG exactly once, in memory.
    img_data = io.BytesIO()
    image.save(img_data, "PNG")
    img_data.seek(0)

    # The timestamp makes the title unique across requests.
    timestamp = math.ceil(time.time())
    title = f"{slug}-{timestamp}"

    # Each run gets its own directory named after the title; the generated
    # image is saved inside it.  makedirs creates missing parents (the old
    # os.mkdir raised FileNotFoundError when parent_dir did not exist) and
    # exist_ok guards against a timestamp collision.
    mode = 0o744
    directory_path = os.path.join(parent_dir, title)
    os.makedirs(directory_path, mode, exist_ok=True)

    # Reuse the already-encoded PNG bytes rather than re-encoding the
    # image a second time via image.save(path).
    file_path_and_name = os.path.join(directory_path, f"{title}.png")
    with open(file_path_and_name, "wb") as out:
        out.write(img_data.getbuffer())
    print(f"{file_path_and_name} created!")

    # success
    # ✨ 😎 ✨
    sparkle = "\U00002728"
    sunglasses = "\U0001F60E"
    print(f"{sparkle} {sunglasses} {sparkle}")
    print("Winning")
    return img_data
@app.route('/')
def myapp():
    """Serve a freshly generated image for the `prompt` query parameter."""
    # The prompt arrives as a query-string parameter; reject requests
    # that omit it.
    prompt = request.args.get("prompt")
    if prompt is None:
        return "Please specify a prompt parameter", 400
    return send_file(run_inference(prompt), mimetype='image/png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment