@rohithreddy, forked from cloneofsimo/vibecheckgen.py (created September 8, 2024)
vibecheckgen.py
import requests
import json
import base64
from PIL import Image
from io import BytesIO
import os
import openai
from openai import OpenAI
import fal_client
# Constants
BASE_URL = "http://localhost:8000"
STABILITY_API_KEY = ""  # set your Stability AI API key here
openai.api_key = ""  # set your OpenAI API key here
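# One option (an alternative, not required) is to read the keys from environment
# variables instead of hard-coding them:
#   STABILITY_API_KEY = os.environ.get("STABILITY_API_KEY", "")
#   openai.api_key = os.environ.get("OPENAI_API_KEY", "")
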
def generate_auraflow(prompt, steps=50, scale=7.0, height=1024, width=1024,
                      use_self_as_nullcond=False, use_student=False):
    # Call a locally hosted AuraFlow server and decode its base64-encoded output.
    url = f"{BASE_URL}/generate"
    payload = {
        "prompt": prompt,
        "steps": steps,
        "scale": scale,
        "height": height,
        "width": width,
        "use_self_as_nullcond": use_self_as_nullcond,
        "use_student": use_student,
    }
    response = requests.post(url, json=payload)
    if response.status_code == 200:
        result = response.json()
        images = []
        for img_str in result["images"]:
            img = Image.open(BytesIO(base64.b64decode(img_str)))
            images.append(img)
        return images[0]  # Return only the first image
    else:
        print(f"AuraFlow Error: {response.status_code}")
        print(response.text)
        return None

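# Example usage (assumes an AuraFlow inference server is listening at BASE_URL;
# the prompt text is only illustrative):
#   image = generate_auraflow("a watercolor painting of a fox")
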
def generate_sd3l(prompt):
    response = requests.post(
        "https://api.stability.ai/v2beta/stable-image/generate/sd3",
        headers={
            "authorization": f"Bearer {STABILITY_API_KEY}",
            "accept": "image/*",
        },
        files={"none": ""},
        data={
            "prompt": prompt,
            "mode": "text-to-image",
            "aspect_ratio": "1:1",
            "output_format": "png",
            "model": "sd3",
        },
    )
    if response.status_code == 200:
        return Image.open(BytesIO(response.content))
    else:
        print(f"SD3-L Error: {response.status_code}")
        print(response.text)
        return None

def generate_sd3m(prompt, model_name="fal-ai/stable-diffusion-v3-medium"):
    try:
        handler = fal_client.submit(
            model_name,
            arguments={
                "prompt": prompt,
                "num_inference_steps": 50,
            },
        )
        result = handler.get()
        result_url = result["images"][0]["url"]
        response = requests.get(result_url)
        return Image.open(BytesIO(response.content))
    except Exception as e:
        print(f"SD3-M Error: {str(e)}")
        return None

def generate_cascade(prompt, model_name):
    try:
        handler = fal_client.submit(
            model_name,
            arguments={
                "prompt": prompt,
            },
        )
        result = handler.get()
        result_url = result["images"][0]["url"]
        response = requests.get(result_url)
        return Image.open(BytesIO(response.content))
    except Exception as e:
        print(f"Stable Cascade Error: {str(e)}")
        return None

def generate_sdxl(prompt, model_name):
    try:
        handler = fal_client.submit(
            model_name,
            arguments={
                "prompt": prompt,
                "num_inference_steps": 50,
            },
        )
        result = handler.get()
        result_url = result["images"][0]["url"]
        response = requests.get(result_url)
        return Image.open(BytesIO(response.content))
    except Exception as e:
        print(f"SDXL Error: {str(e)}")
        return None

def generate_dalle3(prompt):
    # Return None on failure, matching the other generators so a batch run keeps going.
    try:
        client = OpenAI(api_key=openai.api_key)
        response = client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url
        image_response = requests.get(image_url)
        return Image.open(BytesIO(image_response.content))
    except Exception as e:
        print(f"DALL-E 3 Error: {str(e)}")
        return None

def save_image(image, filename):
    if image:
        image.save(filename)
        print(f"Saved: {filename}")
    else:
        print(f"Failed to save: {filename}")

def generate_and_save_images(prompts):
    os.makedirs("outputs", exist_ok=True)
    rootdir = "/home/ubuntu/geneval/image-gallery/public/gallery"
    image_data = {}
    for idx, prompt in enumerate(prompts):
        print(f"Generating images for prompt {idx + 1}: {prompt}")

        # AuraFlow
        auraflow_image = generate_auraflow(prompt)
        save_image(auraflow_image, f"{rootdir}/prompt_{idx + 1}_auraflow.png")

        # SD3-L
        sd3l_image = generate_sd3l(prompt)
        save_image(sd3l_image, f"{rootdir}/prompt_{idx + 1}_sd3l.png")

        # SD3-M
        sd3m_image = generate_sd3m(prompt, model_name="fal-ai/stable-diffusion-v3-medium")
        save_image(sd3m_image, f"{rootdir}/prompt_{idx + 1}_sd3m.png")

        # Stable Cascade
        sdcascade = generate_cascade(prompt, model_name="fal-ai/stable-cascade")
        save_image(sdcascade, f"{rootdir}/prompt_{idx + 1}_sdcascade.png")

        # SDXL
        sdxl_image = generate_sdxl(prompt, model_name="fal-ai/fast-sdxl")
        save_image(sdxl_image, f"{rootdir}/prompt_{idx + 1}_sdxl.png")

        # DALL-E 3
        dalle3_image = generate_dalle3(prompt)
        save_image(dalle3_image, f"{rootdir}/prompt_{idx + 1}_dalle3.png")

        print(f"Completed prompt {idx + 1}\n")
image_data[f"gallery/prompt_{idx+1}_auraflow"] = {
"description": prompt,
"model": "auraflow"
}
image_data[f"gallery/prompt_{idx+1}_sd3l"] = {
"description": prompt,
"model": "sd3l"
}
image_data[f"gallery/prompt_{idx+1}_sd3m"] = {
"description": prompt,
"model": "sd3m"
}
image_data[f"gallery/prompt_{idx+1}_sdcascade"] = {
"description": prompt,
"model": "sdcascade"
}
image_data[f"gallery/prompt_{idx+1}_sdxl"] = {
"description": prompt,
"model": "sdxl"
}
image_data[f"gallery/prompt_{idx+1}_dalle3"] = {
"description": prompt,
"model": "dalle3"
}
with open(f"{rootdir}/image_data.json", "w") as f:
json.dump(image_data, f)
    # Load all the images and save low-res WEBP previews (quality 5 and quality 30).
    for idx, prompt in enumerate(prompts):
        for model in ["auraflow", "sd3l", "sd3m", "sdcascade", "sdxl", "dalle3"]:
            image = Image.open(f"{rootdir}/prompt_{idx + 1}_{model}.png")
            image.save(f"{rootdir}/prompt_{idx + 1}_{model}.webp", "WEBP", quality=5)
            image.save(f"{rootdir}/prompt_{idx + 1}_{model}.higher.webp", "WEBP", quality=30)

if __name__ == "__main__":
    # Prompts come from a JSON list of objects with a "description" field.
    with open("radical_image_descriptions.json") as f:
        prompt_info = json.load(f)
    prompts = [p["description"] for p in prompt_info]
    generate_and_save_images(prompts)
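
# Expected shape of radical_image_descriptions.json, inferred from the loader above
# (the entries here are placeholders, not real data):
#   [
#     {"description": "<prompt text 1>"},
#     {"description": "<prompt text 2>"}
#   ]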