-
-
Save rohithreddy/badf8ffd239b7effe1084c0847f70106 to your computer and use it in GitHub Desktop.
vibecheckgen.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
import json | |
import base64 | |
from PIL import Image | |
from io import BytesIO | |
import os | |
import openai | |
from openai import OpenAI | |
import fal_client | |
# Constants
# Base URL of the locally hosted AuraFlow server queried by generate_auraflow().
BASE_URL = "http://localhost:8000"
# API keys come from the environment rather than being hard-coded in source
# (the published version had them redacted, which left invalid syntax).
STABILITY_API_KEY = os.environ.get("STABILITY_API_KEY")
openai.api_key = os.environ.get("OPENAI_API_KEY")
def generate_auraflow(prompt, steps=50, scale=7.0, height=1024, width=1024,
                      use_self_as_nullcond=False, use_student=False, timeout=600):
    """Generate one image with the local AuraFlow server.

    POSTs the generation parameters as JSON to ``{BASE_URL}/generate`` and
    decodes the first base64-encoded image in the response's ``images`` list.

    Args:
        prompt: Text prompt.
        steps: Sampling step count forwarded to the server.
        scale: Guidance scale forwarded to the server.
        height: Output height in pixels, forwarded to the server.
        width: Output width in pixels, forwarded to the server.
        use_self_as_nullcond: Server-side flag, forwarded as-is.
        use_student: Server-side flag, forwarded as-is.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        A PIL image, or ``None`` if the request failed, returned a non-200
        status, or contained no images (the error is printed).
    """
    url = f"{BASE_URL}/generate"
    payload = {
        "prompt": prompt,
        "steps": steps,
        "scale": scale,
        "height": height,
        "width": width,
        "use_self_as_nullcond": use_self_as_nullcond,
        "use_student": use_student,
    }
    try:
        response = requests.post(url, json=payload, timeout=timeout)
    except requests.RequestException as e:
        # Keep the batch going, like the other generators, instead of crashing.
        print(f"AuraFlow Error: {e}")
        return None
    if response.status_code != 200:
        print(f"AuraFlow Error: {response.status_code}")
        print(response.text)
        return None
    images = response.json().get("images") or []
    if not images:
        print("AuraFlow Error: response contained no images")
        return None
    # Only the first image is returned, so decode just that one.
    return Image.open(BytesIO(base64.b64decode(images[0])))
def generate_sd3l(prompt, timeout=300):
    """Generate one 1:1 PNG image through Stability AI's hosted SD3 endpoint.

    Args:
        prompt: Text prompt.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        A PIL image, or ``None`` on a network error or non-200 response
        (the error is printed).
    """
    try:
        response = requests.post(
            "https://api.stability.ai/v2beta/stable-image/generate/sd3",
            headers={
                "authorization": f"Bearer {STABILITY_API_KEY}",
                "accept": "image/*",
            },
            # Passing a (dummy) file part makes requests send the body as
            # multipart/form-data instead of url-encoded form data.
            files={"none": ''},
            data={
                "prompt": prompt,
                "mode": "text-to-image",
                "aspect_ratio": "1:1",
                "output_format": "png",
                "model": "sd3",
            },
            timeout=timeout,
        )
    except requests.RequestException as e:
        # Keep the batch going, like the other generators, instead of crashing.
        print(f"SD3-L Error: {e}")
        return None
    if response.status_code != 200:
        print(f"SD3-L Error: {response.status_code}")
        print(response.text)
        return None
    return Image.open(BytesIO(response.content))
def generate_sd3m(prompt, model_name="fal-ai/stable-diffusion-v3-medium", timeout=120):
    """Generate one image with Stable Diffusion 3 Medium through fal.

    Submits ``prompt`` to ``model_name`` with 50 inference steps, waits for the
    job, then downloads the first image URL from the result.

    Args:
        prompt: Text prompt.
        model_name: fal application id to submit to.
        timeout: Seconds to wait when downloading the generated image.

    Returns:
        A PIL image, or ``None`` on any failure (the error is printed).
    """
    try:
        handler = fal_client.submit(
            model_name,
            arguments={
                "prompt": prompt,
                "num_inference_steps": 50,
            },
        )
        result = handler.get()
        result_url = result["images"][0]["url"]
        response = requests.get(result_url, timeout=timeout)
        # Report HTTP failures directly rather than letting PIL choke on an error page.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # Deliberate best-effort: one failing model must not abort the batch.
        print(f"SD3-M Error: {str(e)}")
        return None
def generate_cascade(prompt, model_name, timeout=120):
    """Generate one image with Stable Cascade through fal.

    Submits ``prompt`` to ``model_name`` with the service's default settings,
    waits for the job, then downloads the first image URL from the result.

    Args:
        prompt: Text prompt.
        model_name: fal application id to submit to (e.g. "fal-ai/stable-cascade").
        timeout: Seconds to wait when downloading the generated image.

    Returns:
        A PIL image, or ``None`` on any failure (the error is printed).
    """
    try:
        handler = fal_client.submit(
            model_name,
            arguments={
                "prompt": prompt,
            },
        )
        result = handler.get()
        result_url = result["images"][0]["url"]
        response = requests.get(result_url, timeout=timeout)
        # Report HTTP failures directly rather than letting PIL choke on an error page.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # Deliberate best-effort: one failing model must not abort the batch.
        # (Previously mislabelled as "SD3-M Error".)
        print(f"SD Cascade Error: {str(e)}")
        return None
def generate_sdxl(prompt, model_name="fal-ai/fast-sdxl", timeout=120):
    """Generate one image with SDXL through fal.

    Submits ``prompt`` to ``model_name`` with 50 inference steps, waits for the
    job, then downloads the first image URL from the result.

    Args:
        prompt: Text prompt.
        model_name: fal application id to submit to. It was previously ignored
            in favour of a hard-coded "fal-ai/fast-sdxl", which is now the default.
        timeout: Seconds to wait when downloading the generated image.

    Returns:
        A PIL image, or ``None`` on any failure (the error is printed).
    """
    try:
        handler = fal_client.submit(
            model_name,
            arguments={
                "prompt": prompt,
                "num_inference_steps": 50,
            },
        )
        result = handler.get()
        result_url = result["images"][0]["url"]
        response = requests.get(result_url, timeout=timeout)
        # Report HTTP failures directly rather than letting PIL choke on an error page.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # Deliberate best-effort: one failing model must not abort the batch.
        # (Previously mislabelled as "SD3-M Error".)
        print(f"SDXL Error: {str(e)}")
        return None
def generate_dalle3(prompt, timeout=120):
    """Generate one 1024x1024 standard-quality image with DALL-E 3.

    Asks the OpenAI Images API for a single image, then downloads it from the
    returned URL.

    Args:
        prompt: Text prompt.
        timeout: Seconds to wait when downloading the generated image.

    Returns:
        A PIL image, or ``None`` if the OpenAI call or the download failed
        (the error is printed). Before this change, such failures raised and
        aborted the batch, unlike the other generators.
    """
    try:
        client = OpenAI(api_key=openai.api_key)
        generation = client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = generation.data[0].url
        response = requests.get(image_url, timeout=timeout)
        # Report HTTP failures directly rather than letting PIL choke on an error page.
        response.raise_for_status()
    except (openai.OpenAIError, requests.RequestException) as e:
        print(f"DALL-E 3 Error: {e}")
        return None
    return Image.open(BytesIO(response.content))
def save_image(image, filename):
    """Write ``image`` to ``filename`` and report the outcome on stdout.

    A falsy ``image`` (such as the ``None`` a failed generator returns) is not
    written; a failure line is printed instead.
    """
    if not image:
        print(f"Failed to save: {filename}")
        return
    image.save(filename)
    print(f"Saved: {filename}")
def _save_webp_previews(image, stem):
    """Write the two low-res WebP previews (quality 5 and 30) next to ``stem``."""
    image.save(f"{stem}.webp", "WEBP", quality=5)
    image.save(f"{stem}.higher.webp", "WEBP", quality=30)


def generate_and_save_images(prompts,
                             rootdir="/home/ubuntu/geneval/image-gallery/public/gallery"):
    """Render every prompt with every backend and write a gallery to ``rootdir``.

    For prompt number ``i`` (1-based) and each model key, writes
    ``prompt_{i}_{model}.png`` plus ``.webp`` / ``.higher.webp`` previews, then
    writes ``image_data.json``. It maps ``gallery/prompt_{i}_{model}`` to
    ``{"description": prompt, "model": model}``.

    Images whose generator returned ``None`` are skipped (a failure line is
    printed). Previously this made the WebP pass crash on the missing PNG and
    left dangling metadata entries.

    Args:
        prompts: Iterable of text prompts.
        rootdir: Output directory; created if it does not exist.
    """
    # The original created an unused "outputs" dir; the real target is rootdir.
    os.makedirs(rootdir, exist_ok=True)
    # (model key used in file names and metadata, callable taking a prompt)
    generators = [
        ("auraflow", generate_auraflow),
        ("sd3l", generate_sd3l),
        ("sd3m", lambda p: generate_sd3m(p, model_name="fal-ai/stable-diffusion-v3-medium")),
        ("sdcascade", lambda p: generate_cascade(p, model_name="fal-ai/stable-cascade")),
        ("sdxl", lambda p: generate_sdxl(p, model_name="fal-ai/fast-sdxl")),
        ("dalle3", generate_dalle3),
    ]
    image_data = {}
    for idx, prompt in enumerate(prompts, start=1):
        print(f"Generating images for prompt {idx}: {prompt}")
        for model, generate in generators:
            stem = f"prompt_{idx}_{model}"
            image = generate(prompt)
            save_image(image, f"{rootdir}/{stem}.png")
            if image is None:
                # Nothing on disk, so emit neither previews nor a gallery entry.
                continue
            _save_webp_previews(image, f"{rootdir}/{stem}")
            image_data[f"gallery/{stem}"] = {
                "description": prompt,
                "model": model,
            }
        print(f"Completed prompt {idx}\n")
    with open(f"{rootdir}/image_data.json", "w", encoding="utf-8") as f:
        json.dump(image_data, f)
if __name__ == "__main__":
    # The file is a JSON list whose entries each carry a "description" prompt.
    # A context manager closes the handle instead of leaking it.
    with open("radical_image_descriptions.json", encoding="utf-8") as f:
        prompt_info = json.load(f)
    prompts = [p["description"] for p in prompt_info]
    generate_and_save_images(prompts)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment