This script generates images of many cats and one dog using Stable Diffusion XL.
# https://gist.github.com/laurentperrinet/34a9429f0bd9cb653b871ad7bdc0558f

import argparse

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from tqdm import tqdm
def parse_args():
    parser = argparse.ArgumentParser(description="Generate images using Stable Diffusion XL")

    # Model parameters
    parser.add_argument("--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0",
                        help="Base model path or identifier")
    parser.add_argument("--refiner_model", type=str, default="stabilityai/stable-diffusion-xl-refiner-1.0",
                        help="Refiner model path or identifier")
    parser.add_argument("--torch_dtype", type=str, default="float16", choices=["float16", "float32"],
                        help="Torch data type")
    parser.add_argument("--device", type=str, default="mps", choices=["mps", "cuda", "cpu"],
                        help="Device to run inference on")

    # Generation parameters
    parser.add_argument("--prompt", type=str, default="one angry chihuahua dog face and many different cat faces tightly packed in a 13×8 grid. the integrated dog face at row 3 column 4. animals have diverse fur patterns (tabby, calico, siamese, persian, black, white, orange, gray), natural head positions, soft studio lighting, photorealistic 8K.",
                        help="Prompt for image generation")
    parser.add_argument("--negative_prompt", type=str, default="irregular grid, multiple dogs, grid lines, bad alignment, obvious dog, poor blending, lighting mismatch, size difference, low quality, blurry, distorted, artifacts, watermarks, cartoon, painting",
                        help="Negative prompt for image generation")
    parser.add_argument("--num_images", type=int, default=20,
                        help="Number of images to generate")
    parser.add_argument("--base_num_inference_steps", type=int, default=40,
                        help="Number of inference steps for base model")
    parser.add_argument("--refiner_num_inference_steps", type=int, default=40,
                        help="Number of inference steps for refiner model")
    parser.add_argument("--guidance_scale", type=float, default=8.0,
                        help="Guidance scale for generation")
    parser.add_argument("--height", type=int, default=1000,
                        help="Height of generated image")
    parser.add_argument("--width", type=int, default=1600,
                        help="Width of generated image")
    parser.add_argument("--denoising_start", type=float, default=0.8,
                        help="Denoising start value for refiner")
    parser.add_argument("--output_prefix", type=str, default="hundered-cats-and-one-dog",
                        help="Prefix for output filenames")
    return parser.parse_args()
def main():
    args = parse_args()

    # Set torch dtype
    if args.torch_dtype == "float16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # Load base model
    print("Loading base model...")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        args.base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if torch_dtype == torch.float16 else None,
        use_safetensors=True,
    )
    pipe.to(args.device)

    # Load refiner
    print("Loading refiner model...")
    refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        args.refiner_model,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        variant="fp16" if torch_dtype == torch.float16 else None,
    )
    refiner.to(args.device)

    # Generate images
    for i in tqdm(range(args.num_images), desc="Generating images"):
        # Generate with base model
        image = pipe(
            args.prompt,
            negative_prompt=args.negative_prompt,
            num_inference_steps=args.base_num_inference_steps,
            guidance_scale=args.guidance_scale,
            height=args.height,
            width=args.width,
        ).images[0]

        # Refine with refiner
        image = refiner(
            prompt=args.prompt,
            negative_prompt=args.negative_prompt,
            image=image,
            num_inference_steps=args.refiner_num_inference_steps,
            denoising_start=args.denoising_start,
        ).images[0]

        # Save image
        filename = f"{args.output_prefix}_{i+1}.png"
        image.save(filename)
        print(f"Saved {filename}")


if __name__ == "__main__":
    main()
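A minimal usage sketch, assuming the script is saved under a hypothetical name such as cats_and_one_dog.py; every flag shown is optional and simply overrides the defaults defined in parse_args():

    python cats_and_one_dog.py --device cuda --num_images 4 --guidance_scale 7.0

With denoising_start=0.8, the refiner re-runs roughly the last 20% of the denoising schedule on each base image, which is intended to sharpen fine detail such as fur texture without changing the overall grid composition.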