Last active: April 18, 2024 10:49
-
-
Save cavit99/4927cdd6b67576db9a5c4ed6fcc1e43f to your computer and use it in GitHub Desktop.
Stability AI API calls in Python for SD3 generation and Creative Upscale
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import random
import sys
import time
import uuid

import requests
# --- Configuration -----------------------------------------------------------
# The API key is taken from the STABILITY_API_KEY environment variable when it
# is set, so the secret does not have to live in the script; the placeholder
# is only a fallback.
api_key = os.environ.get("STABILITY_API_KEY", "your-stability-api-key")
num_batches = 6  # Number of images to generate per API
prompt = "Your prompt here"
negative_prompt = "optional negative prompt here"  # Optional
aspect_ratio = "2:3"  # 16:9 1:1 21:9 2:3 3:2 4:5 5:4 9:16 9:21
upscale_images = False  # Set to False to skip upscaling
upscale_creativity = "0.25"  # Creativity level for upscaling (sent as form data)
use_core = False  # True: generate with both SD3 and Core; False: SD3 only

# --- Cost estimate (USD) -----------------------------------------------------
# Prices are expressed in Stability credits, where 1 credit = $0.01 USD.
# NOTE(review): credit costs are Stability AI's published prices as of
# April 2024 (SD3 6.5, Core 3, Creative Upscale 25 credits per call);
# re-check the pricing page before relying on the estimate.
USD_PER_CREDIT = 0.01
SD3_CREDITS_PER_IMAGE = 6.5
CORE_CREDITS_PER_IMAGE = 3
UPSCALE_CREDITS_PER_IMAGE = 25

sd3_cost = num_batches * SD3_CREDITS_PER_IMAGE * USD_PER_CREDIT
core_cost = num_batches * CORE_CREDITS_PER_IMAGE * USD_PER_CREDIT if use_core else 0
# Only the SD3 images are upscaled, i.e. one upscale per batch.
upscale_cost = (
    num_batches * UPSCALE_CREDITS_PER_IMAGE * USD_PER_CREDIT if upscale_images else 0
)
total_cost = sd3_cost + core_cost + upscale_cost
# --- Balance check -----------------------------------------------------------
# Fetch the account balance (reported in credits) and stop before any paid
# call if it cannot cover the estimated total cost.
response = requests.get(
    "https://api.stability.ai/v1/user/balance",
    headers={"Authorization": f"Bearer {api_key}"},
    timeout=30,  # Fail instead of hanging forever on a stalled connection
)
if response.status_code != 200:
    raise RuntimeError("Non-200 response: " + str(response.text))
payload = response.json()
account_balance_credits = payload.get("credits", 0)
account_balance_usd = account_balance_credits / 100  # 1 credit = $0.01
if account_balance_usd < total_cost:
    print(
        f"Insufficient account balance. Available balance: ${account_balance_usd:.2f}"
        f" (required: ${total_cost:.2f})"
    )
    # sys.exit works even when the site-provided exit() builtin is absent,
    # and a non-zero status signals the failure to the calling shell.
    sys.exit(1)
print(f"Available balance: ${account_balance_usd:.2f}")
# --- Cost confirmation -------------------------------------------------------
# Show the cost breakdown and require explicit consent before spending credits.
image_count = num_batches * (2 if use_core else 1)
print(f"Total cost for generating {image_count} image(s): ${total_cost:.2f}")
print(f"- SD3 generation cost: ${sd3_cost:.2f}")
if use_core:
    print(f"- Core generation cost: ${core_cost:.2f}")
if upscale_images:
    print(f"- Upscaling cost: ${upscale_cost:.2f}")
user_input = input("Do you want to proceed? (y/n): ")
# Ignore surrounding whitespace and accept "yes", so that a stray space does
# not silently cancel the run.
if user_input.strip().lower() not in ("y", "yes"):
    print("Operation cancelled by user.")
    sys.exit()
# --- Image generation --------------------------------------------------------
# Generate num_batches images with SD3 (and Core when use_core is set). Within
# a batch, every selected API receives the same seed. Each image is saved as
# ./<run-uuid>_<api>_<seed>.png and recorded in generated_images as
# (filename, api).
unique_id = str(uuid.uuid4())  # Shared prefix grouping this run's output files
GENERATE_ENDPOINTS = {
    "sd3": "https://api.stability.ai/v2beta/stable-image/generate/sd3",
    "core": "https://api.stability.ai/v2beta/stable-image/generate/core",
}
apis = ["sd3", "core"] if use_core else ["sd3"]
generated_images = []
for i in range(num_batches):
    seed = random.randint(1, 4294967293)  # New random seed for each batch
    for api in apis:
        response = requests.post(
            GENERATE_ENDPOINTS[api],
            headers={
                "authorization": f"Bearer {api_key}",
                "accept": "image/*",
            },
            # requests only builds a multipart/form-data body when `files` is
            # given; this dummy part ensures that encoding without an upload.
            files={"none": ""},
            data={
                "seed": seed,
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "aspect_ratio": aspect_ratio,
                "output_format": "png",
            },
            timeout=120,  # Avoid hanging forever on a stalled connection
        )
        if response.status_code == 200:
            filename = f"./{unique_id}_{api}_{seed}.png"
            with open(filename, "wb") as file:
                file.write(response.content)
            print(f"Image {i+1} generated using {api.upper()} with seed {seed} and saved as {filename}")
            generated_images.append((filename, api))
            continue

        # Error path: the body is normally JSON, but fall back to the raw text
        # (e.g. a gateway error page) instead of crashing inside .json().
        try:
            error_response = response.json()
        except ValueError:
            error_response = {"status_code": response.status_code, "body": response.text}
        if isinstance(error_response, dict) and error_response.get("name") == "content_moderation":
            print("The request was flagged by the content moderation system. Please modify the prompt and try again.")
            sys.exit(1)
        raise RuntimeError(str(error_response))
# --- Upscale submission ------------------------------------------------------
# For each SD3 image, start a Creative Upscale job and record its generation id
# in upscale_requests as (filename, generation_id). The finished images are
# fetched by polling later.
upscale_requests = []
if upscale_images:
    for filename, api in generated_images:
        if api != "sd3":
            continue  # Only SD3 images are upscaled
        # Close the source image promptly; the original left the handle open.
        with open(filename, "rb") as image_file:
            upscale_response = requests.post(
                "https://api.stability.ai/v2beta/stable-image/upscale/creative",
                headers={
                    "authorization": f"Bearer {api_key}",
                    # NOTE(review): this call answers with a JSON job id rather
                    # than an image; the result's format comes from
                    # output_format below. Header kept as in the original.
                    "accept": "image/png",
                },
                files={"image": image_file},
                data={
                    "prompt": prompt,
                    "creativity": upscale_creativity,
                    # Explicit so the result matches the ".png" name it is saved under.
                    "output_format": "png",
                },
                timeout=120,  # Avoid hanging forever on a stalled upload
            )
        if upscale_response.status_code == 200:
            generation_id = upscale_response.json().get("id")
            print(f"Upscaling request for {filename} submitted. Generation ID: {generation_id}")
            upscale_requests.append((filename, generation_id))
        else:
            # Prefer the JSON error body, falling back to raw text if it is not JSON.
            try:
                error_detail = upscale_response.json()
            except ValueError:
                error_detail = {"status_code": upscale_response.status_code, "body": upscale_response.text}
            raise RuntimeError(str(error_detail))
# --- Upscale results ---------------------------------------------------------
# Poll each upscale job: HTTP 202 means still running, 200 carries the
# finished image. Each result is saved next to its source image with an
# "_upscaled" suffix.
POLL_INTERVAL_SECONDS = 20
MAX_WAIT_SECONDS = 30 * 60  # Give up on a job instead of polling forever
for filename, generation_id in upscale_requests:
    deadline = time.monotonic() + MAX_WAIT_SECONDS
    while True:
        result_response = requests.get(
            f"https://api.stability.ai/v2beta/stable-image/upscale/creative/result/{generation_id}",
            headers={
                "accept": "image/*",
                "authorization": f"Bearer {api_key}",
            },
            timeout=120,  # Avoid hanging forever on a stalled connection
        )
        if result_response.status_code == 202:
            # The upscaling request is still in progress.
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Upscaling for {filename} did not finish within {MAX_WAIT_SECONDS} seconds"
                )
            print(f"Upscaling for {filename} in progress. Waiting for {POLL_INTERVAL_SECONDS} seconds...")
            time.sleep(POLL_INTERVAL_SECONDS)
        elif result_response.status_code == 200:
            print(f"Upscaling for {filename} complete!")
            # Insert the suffix before the extension only, rather than
            # rewriting every ".png" occurrence in the path.
            root, ext = os.path.splitext(filename)
            upscaled_filename = f"{root}_upscaled{ext}"
            with open(upscaled_filename, "wb") as file:
                file.write(result_response.content)
            print(f"Upscaled image saved as {upscaled_filename}")
            break  # Move on to the next upscale request
        else:
            # Prefer the JSON error body, falling back to raw text if it is not JSON.
            try:
                error_detail = result_response.json()
            except ValueError:
                error_detail = {"status_code": result_response.status_code, "body": result_response.text}
            raise RuntimeError(str(error_detail))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.