Created April 1, 2024 21:34
-
-
Save lmarcondes/ebb3b2505fb5826ac4979fa0a3c04359 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/env python | |
from datetime import datetime | |
from openai import OpenAI | |
from argparse import ArgumentParser | |
import json | |
from pathlib import Path | |
from base64 import b64decode | |
from openai.types.image import Image | |
import requests | |
from openai.types.images_response import ImagesResponse | |
def get_arg_parser() -> ArgumentParser:
    """Build the command-line parser for the image generation script.

    Options:
        --prompt (str, required): Text prompt to generate the image from.
        --format (str): Response format, ``"url"`` (default) or ``"b64_json"``.
        --n (int): Number of images to generate (default 1).
        --model (str): Image model to use (default ``"dall-e-3"``).

    Returns:
        ArgumentParser: The configured parser.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt to generate image"
    )
    # Only these two formats are handled when saving the response, so reject
    # anything else up front instead of after a (billed) API round trip.
    parser.add_argument(
        "--format",
        type=str,
        default="url",
        choices=("url", "b64_json"),
        help="Response format",
    )
    parser.add_argument("--n", type=int, default=1, help="Number of images to generate")
    parser.add_argument("--model", type=str, default="dall-e-3", help="Model to use")
    return parser
def write_json(data: dict, filename: str | Path) -> None: | |
"""Write the data to a JSON file. | |
Args: | |
data (dict): The data to write to the JSON file. | |
filename (str): The name of the file to write the JSON data to. | |
""" | |
with open(filename, "w") as f: | |
json.dump(data, f) | |
def save_image(data: bytes, filename: str | Path) -> None: | |
"""Save the image data to a file. | |
Args: | |
data (str): The image data to save. | |
filename (str): The name of the file to save the image data to. | |
""" | |
with open(filename, "wb") as f: | |
f.write(data) | |
def save_image_b64(data: str, filename: str | Path) -> None:
    """Decode base64-encoded image data and write the bytes to ``filename``.

    Args:
        data (str): Base64 text of the image (e.g. the ``b64_json`` field).
        filename (str | Path): Destination path; overwritten if it exists.
    """
    decoded_bytes = b64decode(data)
    save_image(decoded_bytes, filename)
def save_image_from_url(
    url: str, filename: str | Path, timeout: float | None = 60.0
) -> None:
    """Download the image at ``url`` and save it to ``filename``.

    Args:
        url (str): The URL of the image to download.
        filename (str | Path): Destination path; overwritten if it exists.
        timeout (float | None): Seconds ``requests`` waits for the server
            (to connect, and between received bytes) before giving up.
            ``None`` restores the previous behavior of waiting indefinitely.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Without a timeout, requests can block forever on a stalled connection.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    save_image(response.content, filename)
def save_response(response: ImagesResponse, output_root: str | Path = "img") -> None:
    """Save the response to a JSON file and save each generated image.

    Everything is written under ``<output_root>/<created>/``, where
    ``<created>`` is ``response.created`` converted to a local-time ISO 8601
    timestamp (e.g. ``2024-04-01T21:34:00``). The response itself goes to
    ``response.json``, and the i-th image is handed to ``save_model_data``
    with the directory ``image_<i>``.

    Args:
        response (ImagesResponse): The response to save.
        output_root (str | Path): Root directory for saved responses.
            Defaults to ``"img"``, relative to the current working directory.
    """
    # Round-trip through JSON so write_json receives plain built-in types.
    response_parsed = json.loads(response.model_dump_json())
    # NOTE(review): isoformat() contains ':' characters, which Windows does not
    # allow in directory names. Responses created within the same second also
    # share a directory and overwrite each other's files.
    date = datetime.fromtimestamp(response.created).isoformat()
    response_dir = Path(output_root) / date
    response_dir.mkdir(exist_ok=True, parents=True)
    write_json(response_parsed, response_dir / "response.json")
    for i, value in enumerate(response.data):
        save_model_data(value, response_dir / f"image_{i}")
def save_model_data(data: Image, base_dir: Path) -> None:
    """Save a single generated image as ``image.png`` inside ``base_dir``.

    The image is downloaded from ``data.url`` when it is set; otherwise it is
    decoded from ``data.b64_json``. ``base_dir`` (and any missing parents) is
    created only once it is known there is image data to write.

    Args:
        data (Image): One entry of an images response's ``data`` list.
        base_dir (Path): Directory to write ``image.png`` into.

    Raises:
        ValueError: If neither ``url`` nor ``b64_json`` is set on ``data``.
    """
    image_url = data.url
    b64_json_data = data.b64_json
    if image_url is None and b64_json_data is None:
        # Validate before touching the filesystem so a bad response does not
        # leave an empty image directory behind.
        raise ValueError("No image data in response")
    base_dir.mkdir(exist_ok=True, parents=True)
    target = base_dir / "image.png"
    if image_url is not None:
        save_image_from_url(image_url, target)
    else:
        save_image_b64(b64_json_data, target)
def run() -> None:
    """Generate images from a command-line prompt and save them to disk.

    Parses the options defined by ``get_arg_parser`` (``--prompt`` is
    required; ``--format``, ``--n`` and ``--model`` are optional) and calls
    the OpenAI Images API with a fixed 1024x1024 size and "standard" quality.
    The response is then passed to ``save_response``.

    Usage:
        python get-image-dalle3.py --prompt "<prompt>" [--model MODEL] [--n N] [--format {url,b64_json}]
    """
    import sys

    parser = get_arg_parser()
    args, unknown = parser.parse_known_args()
    if unknown:
        # Extra arguments are still tolerated, but report them so typos such
        # as "--fromat" are not silently ignored.
        print(f"Ignoring unrecognized arguments: {' '.join(unknown)}", file=sys.stderr)
    # Build the client only after parsing, so --help and argument errors do
    # not depend on constructing the API client (which may need credentials).
    client = OpenAI()
    print(f"Prompt: {args.prompt}")
    response = client.images.generate(
        model=args.model,
        prompt=args.prompt,
        size="1024x1024",
        quality="standard",
        response_format=args.format,
        n=args.n,
    )
    save_response(response)


if __name__ == "__main__":
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.