Last active
November 24, 2025 00:16
-
-
Save willkurt/d10672f4aa7ea4dfe36f26c879deec8c to your computer and use it in GitHub Desktop.
AI Training Data Generator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Training Data Harvester | |
| Generates training data by using an LLM to prompt DALL-E 3 to create images | |
| Before running, make sure to install the necessary pip packages by running: | |
| `pip install openai requests pydantic pillow` | |
| You will also need to set your OPENAI_API_KEY environment variable. | |
| Then you can run this code as follows: | |
| `python ai_generate.py -n 200 -f train_data` | |
| Where `-n` specifies the number of images to generate | |
| and `-f` the folder in which to store them. | |
| """ | |
| import os | |
| import argparse | |
| import random | |
| from pathlib import Path | |
| from openai import OpenAI | |
| from openai import APIError, BadRequestError | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| from textwrap import dedent | |
| from pydantic import BaseModel | |
| # Subject and style lists for random selection | |
# Creatures handed to the LLM as the "subject" of a monster scene.
# main() draws one of these at random for every generation attempt.
SUBJECTS = [
    "crab", "sand worm", "spider", "octopus", "shark",
    "snake", "bird", "bear", "wolf", "lion",
    "tiger", "elephant", "rhino", "scorpion", "beetle",
    "butterfly", "dragonfly", "mantis", "ant", "bee",
]
# Film genres handed to the LLM as the visual "style" of a scene.
# main() draws one of these at random for every generation attempt.
STYLES = [
    "Sci-fi film", "Horror film", "Film noir", "Western film", "Fantasy film",
    "Thriller film", "Action film", "Monster movie", "B-movie", "Classic horror",
]
# Structured-output schema passed as `response_format` in generate_prompt().
# Documented with comments rather than a class docstring on purpose: pydantic
# copies a model's docstring into the generated JSON schema, which would change
# the request that is sent to the API.
class PromptResponse(BaseModel):
    # The image-generation prompt text written by the LLM.
    prompt: str
def generate_prompt(client: OpenAI, subject: str, style: str) -> str:
    """Generate an image prompt using GPT-4o based on subject and style.

    Args:
        client: OpenAI client instance used for the chat completion.
        subject: The creature to depict as a monster (e.g. "crab").
        style: The film style the scene should follow (e.g. "Sci-fi film").

    Returns:
        The prompt text produced by the model.

    Raises:
        ValueError: If the completion carries no parsed ``PromptResponse``
            (for example because the model refused the request).
    """
    system_prompt = dedent("""
        You are a helpful assistant that generates prompts that will be used to create images.
        This prompts will be used to create images using Dall-E 3
        These images will depict scenes of a certain subject, but as a monster.
        You will also be provided information of roughly what the style should,
        it will typically be some sort of film style (such as Sci-Fi or Horror).
        For example -
        subject: crab
        style: Sci-fi film
        prompt: A terrifying crab monster emerging from the ocean on a distant planet, the scene is from a 1960s sci-film
        """)
    user_prompt = dedent(f"""
        subject: {subject}
        style: {style}
        """)
    completion = client.chat.completions.parse(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        response_format=PromptResponse,
    )
    message = completion.choices[0].message
    result = message.parsed
    if result is None:
        # `parsed` is empty when the model declines to answer; surface that
        # explicitly instead of failing with an AttributeError on `.prompt`.
        refusal = getattr(message, "refusal", None)
        raise ValueError(
            f"No prompt returned for subject={subject!r}, style={style!r}"
            f" (refusal: {refusal!r})"
        )
    return result.prompt
def generate_and_save_image(
    client: OpenAI,
    prompt: str,
    save_path: Path,
    size: str = "1024x1024",
    quality: str = "standard",
    timeout: float = 60.0,
) -> Image.Image:
    """
    Generate an image using OpenAI's DALL-E API and save it to a file.

    The prompt is also written next to the image, in a text file with the
    same name and a ``.txt`` suffix.

    Args:
        client: OpenAI client instance
        prompt: The text description of the image to generate
        save_path: Path to save the image
        size: Image size - options: "1024x1024", "1792x1024", or "1024x1792"
        quality: Image quality - "standard" or "hd"
        timeout: Seconds to wait for the image download before giving up

    Returns:
        PIL.Image: The generated image

    Raises:
        requests.HTTPError: If the image download returns an error status.
        requests.Timeout: If the image download exceeds ``timeout``.
    """
    print(f"Generating image for prompt: '{prompt}'...")

    # Generate the image
    response = client.images.generate(
        model="dall-e-3",
        prompt=prompt,
        size=size,
        quality=quality,
        n=1,
    )

    # Get the image URL
    image_url = response.data[0].url
    print("Image generated successfully!")

    # Download the image. Without a timeout a stalled connection would block
    # forever, and without the status check an HTTP error page would be handed
    # to PIL and fail with a misleading "cannot identify image file" error.
    image_response = requests.get(image_url, timeout=timeout)
    image_response.raise_for_status()
    image = Image.open(BytesIO(image_response.content))

    # Save the image
    image.save(save_path)
    print(f"Image saved to: {save_path}")

    # Save the prompt to a text file with the same name. UTF-8 is forced so
    # that non-ASCII characters in the prompt are written identically on
    # every platform.
    prompt_path = save_path.with_suffix(".txt")
    with open(prompt_path, "w", encoding="utf-8") as f:
        f.write(prompt)
    print(f"Prompt saved to: {prompt_path}")

    return image
def _next_image_number(output_folder: Path) -> int:
    """Return the first unused index for ``image_XXXX.png`` files in a folder."""
    existing_images = list(output_folder.glob("image_*.png"))
    if not existing_images:
        print("No existing images found, starting from image_0000.png")
        return 0

    # Extract numbers from filenames (format: image_XXXX.png). isdecimal() is
    # used because it only accepts characters that int() can actually parse.
    numbers = []
    for f in existing_images:
        parts = f.stem.split("_")
        if len(parts) == 2 and parts[0] == "image" and parts[1].isdecimal():
            numbers.append(int(parts[1]))

    if not numbers:
        print("Found existing files but couldn't parse numbers, starting from image_0000.png")
        return 0

    start_num = max(numbers) + 1
    print(f"Found {len(existing_images)} existing image(s), continuing from image_{start_num:04d}.png")
    return start_num


def _is_content_violation(error: Exception) -> bool:
    """Return True if an API error looks like a content-policy rejection."""
    message = str(error).lower()
    if "content_policy" in message:
        return True
    if "content" in message and any(
        word in message for word in ("policy", "violation", "safety")
    ):
        return True
    code = getattr(error, "code", None)
    kind = getattr(error, "type", None)
    return "content_policy_violation" in (code, kind)


def main():
    """Command-line entry point: generate ``-n`` images into the ``-f`` folder.

    Each attempt picks a random subject and style, asks the LLM for a prompt
    and then generates and saves the image. Content-policy rejections are
    retried with a new random subject/style. Other failures are retried as
    well, but the run is aborted after several of them in a row so that a
    persistent problem (bad API key, no network, ...) cannot loop forever.

    Returns:
        0 when the requested number of images was generated, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Generate training data images using DALL-E 3"
    )
    parser.add_argument(
        "-n",
        type=int,
        default=5,
        help="Number of images to generate (default: 5)",
    )
    parser.add_argument(
        "-f",
        "--folder",
        type=str,
        default="images",
        help="Folder to save images (default: images)",
    )
    args = parser.parse_args()

    # Validate the key BEFORE building the client: the OpenAI constructor
    # itself raises when no key is available, which would otherwise bypass
    # the friendly message below.
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY environment variable not set")
        print("Please set it with: export OPENAI_API_KEY='your-api-key-here'")
        return 1
    client = OpenAI(api_key=api_key)

    # Create output folder if it doesn't exist
    output_folder = Path(args.folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Find the next image number by checking existing files
    start_num = _next_image_number(output_folder)

    print(f"Generating {args.n} images...")
    print(f"Output folder: {output_folder}")
    print()

    # Generate images - continue until we have the requested number of
    # successful generations, or until too many attempts fail back to back.
    max_consecutive_failures = 5
    consecutive_failures = 0
    successful_count = 0
    image_num = start_num
    attempt_count = 0
    while successful_count < args.n:
        attempt_count += 1

        # Randomly select subject and style
        subject = random.choice(SUBJECTS)
        style = random.choice(STYLES)
        print(f"[Attempt {attempt_count}] Generating image {image_num:04d} ({successful_count + 1}/{args.n} successful)")
        print(f" Subject: {subject}")
        print(f" Style: {style}")

        try:
            # Generate prompt using LLM, then generate and save the image
            prompt = generate_prompt(client, subject, style)
            image_path = output_folder / f"image_{image_num:04d}.png"
            generate_and_save_image(client, prompt, image_path)
        except (APIError, BadRequestError) as e:
            if _is_content_violation(e):
                print(f" ⚠ Content policy violation - skipping this attempt (not counted)")
                print(f" Error: {e}")
                print()
                # Expected from time to time: neither a success nor a
                # failure. Try again with a new random subject/style and
                # reuse the same image number.
                continue
            # Other API errors - still skip but log
            print(f" ✗ API Error generating image {image_num:04d}: {e}")
            print()
            consecutive_failures += 1
        except Exception as e:
            # Other unexpected errors
            print(f" ✗ Error generating image {image_num:04d}: {e}")
            print()
            consecutive_failures += 1
        else:
            print(f" ✓ Completed image {image_num:04d}")
            print()
            # Successfully generated - increment counters
            successful_count += 1
            image_num += 1
            consecutive_failures = 0

        if consecutive_failures >= max_consecutive_failures:
            print(
                f"Aborting after {consecutive_failures} consecutive failures "
                f"({successful_count}/{args.n} images generated in {attempt_count} attempts)"
            )
            return 1

    print(f"Done! Successfully generated {successful_count} images in {output_folder} (after {attempt_count} attempts)")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status. SystemExit is a
    # true builtin, whereas exit() is only injected by the `site` module and
    # is missing when Python runs with -S or in some embedded/frozen setups.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment