Skip to content

Instantly share code, notes, and snippets.

@laurentperrinet
Last active October 6, 2025 09:33
Show Gist options
  • Save laurentperrinet/34a9429f0bd9cb653b871ad7bdc0558f to your computer and use it in GitHub Desktop.
Save laurentperrinet/34a9429f0bd9cb653b871ad7bdc0558f to your computer and use it in GitHub Desktop.
This script generates images of many cats and one dog using Stable Diffusion XL (a base model followed by a refiner model).
# https://gist.github.com/laurentperrinet/34a9429f0bd9cb653b871ad7bdc0558f
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
import torch
import argparse
def parse_args(argv=None):
    """Parse the command-line options for SDXL image generation.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is exactly
            what the former zero-argument call did.

    Returns:
        argparse.Namespace holding the model, device, prompt, sampling and
        output options defined below.

    Exits with status 2 (via ``parser.error``) when ``--denoising_start`` is
    not strictly between 0 and 1, or when ``--height`` / ``--width`` is not a
    positive multiple of 8.
    """
    parser = argparse.ArgumentParser(description="Generate images using Stable Diffusion XL")

    # Model parameters
    parser.add_argument("--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0",
                        help="Base model path or identifier")
    parser.add_argument("--refiner_model", type=str, default="stabilityai/stable-diffusion-xl-refiner-1.0",
                        help="Refiner model path or identifier")
    parser.add_argument("--torch_dtype", type=str, default="float16", choices=["float16", "float32"],
                        help="Torch data type")
    parser.add_argument("--device", type=str, default="mps", choices=["mps", "cuda", "cpu"],
                        help="Device to run inference on")

    # Generation parameters
    parser.add_argument(
        "--prompt",
        type=str,
        default=(
            "one angry chihuahua dog face and many different cat faces tightly "
            "packed in a 13×8 grid. the integrated dog face at row 3 column 4. "
            "animals have diverse fur patterns (tabby, calico, siamese, persian, "
            "black, white, orange, gray), natural head positions, soft studio "
            "lighting, photorealistic 8K."
        ),
        help="Prompt for image generation",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=(
            "irregular grid, multiple dogs, grid lines, bad alignment, obvious dog, "
            "poor blending, lighting mismatch, size difference, low quality, blurry, "
            "distorted, artifacts, watermarks, cartoon, painting"
        ),
        help="Negative prompt for image generation",
    )
    parser.add_argument("--num_images", type=int, default=20,
                        help="Number of images to generate")
    parser.add_argument("--base_num_inference_steps", type=int, default=40,
                        help="Number of inference steps for base model")
    parser.add_argument("--refiner_num_inference_steps", type=int, default=40,
                        help="Number of inference steps for refiner model")
    parser.add_argument("--guidance_scale", type=float, default=8.0,
                        help="Guidance scale for generation")
    parser.add_argument("--height", type=int, default=1000,
                        help="Height of generated image")
    parser.add_argument("--width", type=int, default=1600,
                        help="Width of generated image")
    parser.add_argument("--denoising_start", type=float, default=0.8,
                        help="Denoising start value for refiner")
    parser.add_argument("--output_prefix", type=str, default="hundered-cats-and-one-dog",
                        help="Prefix for output filenames")

    args = parser.parse_args(argv)

    # The value is the fraction of the noise schedule at which work passes
    # from the base model to the refiner; 0 or 1 would leave one of the two
    # models with nothing to do, so require a value strictly inside (0, 1).
    if not 0.0 < args.denoising_start < 1.0:
        parser.error(f"--denoising_start must be strictly between 0 and 1, got {args.denoising_start}")

    # Fail fast, before any model is downloaded or loaded: the SDXL
    # pipelines require image dimensions divisible by 8.
    for dim in ("height", "width"):
        value = getattr(args, dim)
        if value <= 0 or value % 8:
            parser.error(f"--{dim} must be a positive multiple of 8, got {value}")

    return args
def main():
    """Generate ``args.num_images`` images with an SDXL base + refiner pair.

    Each image is produced in two stages that share one noise schedule:
    the base model runs the first ``args.denoising_start`` fraction of the
    schedule and returns still-noisy latents; the refiner resumes at that
    same fraction and finishes denoising. Results are written as
    ``{args.output_prefix}_{k}.png`` (k starting at 1) in the current
    working directory.
    """
    args = parse_args()

    # Map the CLI string onto a torch dtype.
    torch_dtype = torch.float16 if args.torch_dtype == "float16" else torch.float32
    # Half-precision runs request the "fp16" weight variant; full precision
    # uses the default weights.
    variant = "fp16" if torch_dtype == torch.float16 else None

    # Load base model
    print("Loading base model...")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        args.base_model,
        torch_dtype=torch_dtype,
        variant=variant,
        use_safetensors=True,
    )
    pipe.to(args.device)

    # Load refiner
    print("Loading refiner model...")
    refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        args.refiner_model,
        torch_dtype=torch_dtype,
        variant=variant,
        use_safetensors=True,
    )
    refiner.to(args.device)

    from tqdm import tqdm

    for i in tqdm(range(args.num_images), desc="Generating images"):
        # Stage 1 (base): stop at `denoising_start` and hand over latents
        # rather than a decoded image. With `denoising_start` set, the
        # refiner treats its input as only partly denoised and does not add
        # noise, so it must receive the base model's intermediate latents;
        # feeding it a fully denoised picture breaks the two-stage split.
        latents = pipe(
            args.prompt,
            negative_prompt=args.negative_prompt,
            num_inference_steps=args.base_num_inference_steps,
            guidance_scale=args.guidance_scale,
            height=args.height,
            width=args.width,
            denoising_end=args.denoising_start,
            output_type="latent",
        ).images

        # Stage 2 (refiner): resume the schedule at the same fraction.
        # NOTE: guidance_scale is not forwarded here, so the refiner uses its
        # pipeline default.
        image = refiner(
            prompt=args.prompt,
            negative_prompt=args.negative_prompt,
            image=latents,
            num_inference_steps=args.refiner_num_inference_steps,
            denoising_start=args.denoising_start,
        ).images[0]

        filename = f"{args.output_prefix}_{i+1}.png"
        image.save(filename)
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"Saved {filename}")
# Script entry point: run the generator only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment