Created
November 11, 2023 19:39
-
-
Save jepler/39efb2797b3cbef5eecf6f858bf47f95 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import binascii | |
import hashlib | |
import io | |
import json | |
import re | |
import subprocess | |
import sys | |
import httpx | |
from PIL import Image | |
from PIL.PngImagePlugin import PngInfo | |
import click | |
from chap.key import get_key | |
# Matches every run of characters other than ASCII letters, digits, "-" and
# "_"; main() replaces each run with a single "-" when deriving a filename
# from the prompt.
unsafe_chars = re.compile(r"[^a-zA-Z0-9-_]+")

# Value sent as the bearer token on the image-generation request.
# NOTE(review): presumably get_key reads a stored secret named
# "openai_api_key" -- confirm against chap.key.
key = get_key("openai_api_key")
@click.command
@click.option(
    "--keep",
    is_flag=True,
    help="Add the magic string to disable prompt elaboration in dall-e-3 (do not use with dall-e-2)",
)
@click.option(
    "--model",
    type=str,
    default="dall-e-3",
    help="dall-e-2 or dall-e-3 at the time of writing",
)
@click.option(
    "--size",
    type=str,
    default="1024x1024",
    help="256x256, 512x512, or 1024x1024 for dall-e-2, 1024x1024, 1024x1792, or 1792x1024 for dall-e-3",
)
@click.option(
    "--style", type=str, default="vivid", help="vivid or natural (dall-e 3 only)"
)
@click.option(
    "--quality", type=str, default="standard", help="standard or hd (dall-e 3 only)"
)
@click.option(
    "--action",
    type=str,
    default=None,
    help="Executable program to run on each image (e.g., open, firefox)",
)
@click.argument("qstr", nargs=-1, required=False)
def main(keep, model, size, style, quality, action, qstr=()):
    """Generate images from a text prompt and save each one as a PNG.

    The prompt is the space-joined ``qstr`` words, or is read interactively
    when no words are given. Every image in the response is written to the
    current directory under a name built from the sanitized prompt plus a
    short content hash, with the prompt (and the revised prompt, when the
    response carries one) stored as PNG text metadata.

    Args:
        keep: Prefix the prompt with the "use it AS-IS" instruction; ignored
            when ``model`` is ``"dall-e-2"``.
        model: Model name placed in the request.
        size: Image size string placed in the request.
        style: Style value placed in the request.
        quality: Quality value placed in the request.
        action: Optional executable invoked as ``action <filename>`` after
            each image is saved.
        qstr: Words of the prompt given on the command line.

    Raises:
        SystemExit: If the API answers with a non-200 status or the response
            body does not have the expected shape.
    """
    if qstr:
        baseprompt = " ".join(qstr)
    else:
        baseprompt = input("Image description: ")

    if keep and model != "dall-e-2":
        prompt = (
            "I NEED to test how the tool works with extremely simple prompts. DO NOT add any detail, just use it AS-IS:"
            + baseprompt
        )
    else:
        prompt = baseprompt

    # NOTE(review): quality and style are sent for every model even though
    # the option help marks them as dall-e-3 only -- confirm the API accepts
    # them for dall-e-2.
    response = httpx.post(
        "https://api.openai.com/v1/images/generations",
        json={
            "model": model,
            "prompt": prompt,
            "size": size,
            "quality": quality,
            "style": style,
            "response_format": "b64_json",
        },
        headers={
            "Authorization": f"Bearer {key}",
        },
        timeout=360,
    )

    if response.status_code != 200:
        # Prefer the structured error message; fall back to the raw body when
        # the payload is not JSON or not shaped like {"error": {"message": ...}}.
        try:
            detail = response.json()["error"]["message"]
        except (KeyError, IndexError, TypeError, json.decoder.JSONDecodeError):
            detail = response.text
        raise SystemExit(f"Failure {detail} ({response.status_code})")

    try:
        j = response.json()
        for row in j["data"]:
            data = binascii.a2b_base64(row["b64_json"])
            # A short content hash keeps repeated runs of the same prompt
            # from overwriting each other.
            digest = hashlib.sha256(data).hexdigest()[:8]
            filename = f"{unsafe_chars.sub('-', baseprompt)[:96]}-{digest}.png"
            image = Image.open(io.BytesIO(data))
            metadata = PngInfo()
            metadata.add_text("prompt", baseprompt)
            if (revised_prompt := row.get("revised_prompt")) is not None:
                print(f"revised prompt: {revised_prompt}")
                metadata.add_text("revised_prompt", revised_prompt)
            print(f"Saving to {filename}")
            image.save(filename, pnginfo=metadata)
            if action:
                subprocess.run([action, filename])
    except (KeyError, IndexError, json.decoder.JSONDecodeError):
        raise SystemExit(f"Failure {response.text} ({response.status_code})")
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment