Created
March 28, 2024 21:50
-
-
Save Michael-F-Ellis/9ddfee002d31e92816c34acbb194aaa3 to your computer and use it in GitHub Desktop.
Stable Cascade App code discussed in a question submitted to StackOverflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gradio as gr | |
import spaces | |
from diffusers import StableCascadeCombinedPipeline | |
import os | |
import torch | |
from PIL import Image | |
import random | |
# Constants
repo = "stabilityai/stable-cascade"

# Load the combined Stable Cascade pipeline (bf16 weights) and move it to the
# GPU at import time. NOTE: `pipe` is only defined when CUDA is available.
if torch.cuda.is_available():
    pipe = StableCascadeCombinedPipeline.from_pretrained(
        repo,
        torch_dtype=torch.bfloat16,
        variant="bf16",
    )
    pipe.to("cuda")
# The generate function
@spaces.GPU(enable_queue=True)
def generate_image(prompt):
    """Generate one 1024x1024 image from a text prompt with Stable Cascade.

    A fresh random seed in [-100000, 100000] is drawn on every call and used
    to seed a CUDA ``torch.Generator`` passed to the pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        The first image of the pipeline result (``results.images[0]``).

    Raises:
        gr.Error: If the module-level ``pipe`` was never created, which is
            the case when CUDA was not available when the module was loaded.
    """
    # `pipe` is only assigned at import time when CUDA is available. Turn the
    # otherwise bare NameError into a message the Gradio UI can display.
    try:
        pipeline = pipe
    except NameError as err:
        raise gr.Error(
            "Stable Cascade pipeline is not loaded: no CUDA device was "
            "available at start-up."
        ) from err

    seed = random.randint(-100000, 100000)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # NOTE: image conditioning (an extra `images` argument forwarded as
    # `images=[images]`) is currently disabled; only the text prompt is used.
    results = pipeline(
        prompt=prompt,
        height=1024,
        width=1024,
        num_inference_steps=20,
        generator=generator,
    )
    return results.images[0]
# ------------- Gradio Interface -----------------------
description = """
This demo utilizes the StableCascade combined pipeline
"""

# NOTE(review): the source's indentation was lost; the nesting below (prompt
# and button in one row, images at Blocks level) is reconstructed -- confirm.
with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Via Stable Cascade ⚡</center></h1>")
    gr.Markdown(description)

    # Prompt textbox and submit button.
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label='Enter your prompt',
                scale=8,
                value="holding a puppy",
            )
            submit = gr.Button(scale=1, variant='primary')

    # `imgin` is shown in the UI but is not passed to generate_image.
    imgin = gr.Image(
        label='Input Image',
        type='pil',
        height=1024,
        width=1024,
        interactive=True,
    )
    imgout = gr.Image(
        label='StableCascade Generated Image',
        height=1024,
        width=1024,
    )

    # Pressing Enter in the prompt box and clicking the button both run the
    # same handler with the same wiring.
    for trigger in (prompt.submit, submit.click):
        trigger(
            fn=generate_image,
            inputs=[prompt],
            outputs=imgout,
        )

demo.queue().launch()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment