Skip to content

Instantly share code, notes, and snippets.

@jepler
Created November 11, 2023 19:39
Show Gist options
  • Save jepler/39efb2797b3cbef5eecf6f858bf47f95 to your computer and use it in GitHub Desktop.
Save jepler/39efb2797b3cbef5eecf6f858bf47f95 to your computer and use it in GitHub Desktop.
#!/usr/bin/python
import binascii
import hashlib
import io
import json
import os
import re
import subprocess
import sys

import click
import httpx
from PIL import Image
from PIL.PngImagePlugin import PngInfo

from chap.key import get_key
# Runs of characters other than ASCII letters, digits, "-" and "_"; each run is
# replaced by "-" when a prompt is turned into a filename stem.
_UNSAFE_CHARS_PATTERN = r"[^a-zA-Z0-9-_]+"
unsafe_chars = re.compile(_UNSAFE_CHARS_PATTERN)

# OpenAI API key, fetched once through chap.key.get_key when the module loads.
key = get_key("openai_api_key")
# Magic prefix that asks dall-e-3 to use the prompt as-is instead of
# elaborating it (enabled by --keep).
_KEEP_PREFIX = (
    "I NEED to test how the tool works with extremely simple prompts. "
    "DO NOT add any detail, just use it AS-IS:"
)


def _build_prompt(baseprompt, model, keep):
    """Return the prompt to send to the API.

    With *keep* set (and a model other than dall-e-2), *baseprompt* is
    prefixed with ``_KEEP_PREFIX``; otherwise it is returned unchanged.
    """
    if keep and model != "dall-e-2":
        return _KEEP_PREFIX + baseprompt
    return baseprompt


def _failure_message(response):
    """Return the SystemExit message for an unsuccessful API response.

    Uses ``error.message`` from the JSON body when it is present; otherwise
    falls back to the raw response text.
    """
    try:
        message = response.json()["error"]["message"]
    except (KeyError, IndexError, TypeError, json.decoder.JSONDecodeError):
        # TypeError: the body is valid JSON but not nested dicts (e.g. a list,
        # or "error" is a plain string).
        return f"Failure {response.text} ({response.status_code})"
    return f"Failure {message} ({response.status_code})"


def _request_generation(prompt, model, size, style, quality):
    """POST one image-generation request and return the successful response.

    Raises:
        SystemExit: on a transport error (connection failure, timeout, ...)
            or on any non-200 status code.
    """
    try:
        response = httpx.post(
            "https://api.openai.com/v1/images/generations",
            json={
                "model": model,
                "prompt": prompt,
                "size": size,
                # NOTE(review): quality and style are sent for every model even
                # though the CLI help marks them "dall-e 3 only"; confirm the
                # API accepts them with dall-e-2.
                "quality": quality,
                "style": style,
                "response_format": "b64_json",
            },
            headers={
                "Authorization": f"Bearer {key}",
            },
            timeout=360,
        )
    except httpx.HTTPError as err:
        # Report network problems like API errors instead of a traceback.
        raise SystemExit(f"Failure {type(err).__name__}: {err}") from err
    if response.status_code != 200:
        raise SystemExit(_failure_message(response))
    return response


def _save_image(row, baseprompt, action):
    """Decode one base64 image from the API response and save it as a PNG.

    The filename is the sanitized prompt (truncated to 96 characters) plus the
    first 8 hex digits of the image's SHA-256, to tell apart images made from
    the same prompt. The original prompt, and the revised prompt when the API
    returns one, are stored as PNG text chunks. If *action* is given it is run
    with the saved file's path as its only argument.
    """
    data = binascii.a2b_base64(row["b64_json"])
    digest = hashlib.sha256(data).hexdigest()[:8]
    filename = f"{unsafe_chars.sub('-', baseprompt)[:96]}-{digest}.png"
    image = Image.open(io.BytesIO(data))
    metadata = PngInfo()
    metadata.add_text("prompt", baseprompt)
    if (revised_prompt := row.get("revised_prompt")) is not None:
        print(f"revised prompt: {revised_prompt}")
        metadata.add_text("revised_prompt", revised_prompt)
    print(f"Saving to {filename}")
    image.save(filename, pnginfo=metadata)
    if action:
        # The sanitized name may begin with "-" (e.g. a prompt that starts with
        # a space), which the action program would parse as an option; an
        # absolute path never starts with "-".
        subprocess.run([action, os.path.abspath(filename)])


@click.command
@click.option(
    "--keep",
    is_flag=True,
    help="Add the magic string to disable prompt elaboration in dall-e-3 (do not use with dall-e-2)",
)
@click.option(
    "--model",
    type=str,
    default="dall-e-3",
    help="dall-e-2 or dall-e-3 at the time of writing",
)
@click.option(
    "--size",
    type=str,
    default="1024x1024",
    help="256x256, 512x512, or 1024x1024 for dall-e-2, 1024x1024, 1024x1792, or 1792x1024 for dall-e-3",
)
@click.option(
    "--style", type=str, default="vivid", help="vivid or natural (dall-e 3 only)"
)
@click.option(
    "--quality", type=str, default="standard", help="standard or hd (dall-e 3 only)"
)
@click.option(
    "--action",
    type=str,
    default=None,
    help="Executable program to run on each image (e.g., open, firefox)",
)
@click.argument("qstr", nargs=-1, required=False)
def main(keep, model, size, style, quality, action, qstr=()):
    """Generate images from a text prompt with OpenAI's image API and save them as PNG files.

    The prompt is the QSTR words joined by spaces; when none are given it is
    read interactively.
    """
    baseprompt = " ".join(qstr) if qstr else input("Image description: ")
    prompt = _build_prompt(baseprompt, model, keep)
    response = _request_generation(prompt, model, size, style, quality)
    try:
        for row in response.json()["data"]:
            _save_image(row, baseprompt, action)
    except (KeyError, IndexError, TypeError, json.decoder.JSONDecodeError):
        # Success status but an unexpected body shape: show the raw body.
        raise SystemExit(f"Failure {response.text} ({response.status_code})")
if __name__ == "__main__":
    # Script entry point: run the click command `main`.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment