Created
August 11, 2023 09:52
-
-
Save sayakpaul/516dc7158f01fb674ab4f3e968429ab2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "bae975ab-1261-4399-87e6-5d71d457e601", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import PIL\n", | |
"import requests\n", | |
"import torch\n", | |
"from diffusers import StableDiffusionXLInstructPix2PixPipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "37c55618-8606-4742-9f2c-d7c333e879ca", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"MODEL_ID = \"sayakpaul/sdxl-instructpix2pix\"\n", | |
"SEED = 0\n", | |
"pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(\n", | |
" MODEL_ID, torch_dtype=torch.float16\n", | |
").to(\"cuda\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9caa8902-77a0-4b02-921f-c8d17ddcfb4d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Quietly download the two sample images; later cells open them as\n", | |
"# example.jpg and mountain.png from the working directory.\n", | |
"!wget -q https://huggingface.co/spaces/timbrooks/instruct-pix2pix/resolve/main/imgs/example.jpg\n", | |
"!wget -q https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2f3ca300-6960-4cfa-a808-aac24df90bc0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import hashlib\n", | |
"\n", | |
"\n", | |
"def infer(\n", | |
" prompt: str,\n", | |
" image: PIL.Image.Image,\n", | |
" guidance_scale=5,\n", | |
" image_guidance_scale=2,\n", | |
" num_inference_steps=20,\n", | |
"):\n", | |
" \"\"\"Performs inference with the pipeline.\"\"\"\n", | |
" hash_image = hashlib.sha1(image.tobytes()).hexdigest()\n", | |
" filename = f\"{str(hash_image)}_gs@{guidance_scale}_igs@{image_guidance_scale}_steps@{num_inference_steps}.png\"\n", | |
" edited_image = pipe(\n", | |
" prompt=prompt,\n", | |
" image=image,\n", | |
" guidance_scale=guidance_scale,\n", | |
" image_guidance_scale=image_guidance_scale,\n", | |
" num_inference_steps=num_inference_steps,\n", | |
" generator=torch.manual_seed(SEED),\n", | |
" ).images[0]\n", | |
" edited_image.save(filename)\n", | |
" return hash_image" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "db779600-d9e8-42f6-9fca-57c409f44ef4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from typing import List\n", | |
"\n", | |
"def run_bulk_experiments(\n", | |
" image: PIL.Image.Image,\n", | |
" edit_prompt: str,\n", | |
" guidance_scales: List[float],\n", | |
" image_guidance_scales: List[float],\n", | |
" steps: List[int],\n", | |
"):\n", | |
" \"\"\"Runs bulk experiments with the pipeline.\"\"\"\n", | |
" for gs in guidance_scales:\n", | |
" for igs in image_guidance_scales:\n", | |
" for steps_ in steps:\n", | |
" hash_image = infer(\n", | |
" edit_prompt,\n", | |
" image,\n", | |
" guidance_scale=gs,\n", | |
" image_guidance_scale=igs,\n", | |
" num_inference_steps=steps_,\n", | |
" )\n", | |
" return hash_image" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "bcc68fa3-88f5-469d-b093-3ed2ccb781c1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import wandb\n", | |
"import glob\n", | |
"\n", | |
"\n", | |
"def log_to_wandb(initial_image_path: str, edit_prompt: str, image_hex: str):\n", | |
" \"\"\"Bulk logs results to wandb.\"\"\"\n", | |
" wandb.init(\n", | |
" project=\"instructpix2pix-sdxl-results\",\n", | |
" config={\"model_id\": MODEL_ID, \"seed\": SEED},\n", | |
" )\n", | |
" table = wandb.Table(\n", | |
" columns=[\n", | |
" \"Initial Image\",\n", | |
" \"Prompt\",\n", | |
" \"Edited Image\",\n", | |
" \"Guidance Scale\",\n", | |
" \"Image Guidance Scale\",\n", | |
" \"Number of Steps\",\n", | |
" ]\n", | |
" )\n", | |
"\n", | |
" edited_images = sorted(glob.glob(f\"{image_hex}_*.png\"))\n", | |
" for edited_image in edited_images:\n", | |
" gs = float(edited_image.split(\"_\")[1].split(\"@\")[-1])\n", | |
" igs = float(edited_image.split(\"_\")[2].split(\"@\")[-1])\n", | |
" steps = int(edited_image.split(\"_\")[3].split(\"@\")[-1].split(\".\")[0])\n", | |
" table.add_data(\n", | |
" wandb.Image(initial_image_path),\n", | |
" edit_prompt,\n", | |
" wandb.Image(edited_image),\n", | |
" gs,\n", | |
" igs,\n", | |
" steps,\n", | |
" )\n", | |
" wandb.log({\"results\": table})\n", | |
" wandb.finish()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0dca3588-4b71-4d38-b40c-27f0aca2c2ea", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Open the two sample images from the working directory.\n", | |
"cyborg = PIL.Image.open(\"example.jpg\")\n", | |
"mountain = PIL.Image.open(\"mountain.png\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2f44f8bd-7386-4675-8e98-2a682ddbacac", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Sweep the cyborg edit over the full grid below: 3 guidance scales x 3 image\n", | |
"# guidance scales x 4 step counts, i.e. 36 pipeline runs.\n", | |
"prompt = \"Turn him into a cyborg!\"\n", | |
"guidance_scales = [5, 7, 7.5]\n", | |
"image_guidance_scales = [1, 1.5, 2]\n", | |
"steps = [20, 25, 40, 50]\n", | |
"\n", | |
"hash_image = run_bulk_experiments(\n", | |
" image=cyborg,\n", | |
" edit_prompt=prompt,\n", | |
" guidance_scales=guidance_scales,\n", | |
" image_guidance_scales=image_guidance_scales,\n", | |
" steps=steps,\n", | |
")\n", | |
"# Show the returned hash; the next cell uses it to find the saved result files.\n", | |
"hash_image" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "7838ed52-3690-46b4-9136-e8fbf26381a8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Log the cyborg results to W&B; relies on `prompt` and `hash_image` from the previous cell.\n", | |
"log_to_wandb(initial_image_path=\"example.jpg\", edit_prompt=prompt, image_hex=hash_image)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "42fad38a-1a0f-46f0-aacc-212fe893c9c8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Same settings grid as the cyborg sweep, applied to the mountain image.\n", | |
"# Note: this rebinds the notebook-level `prompt` and `hash_image`.\n", | |
"prompt = \"make the mountains snowy\"\n", | |
"guidance_scales = [5, 7, 7.5]\n", | |
"image_guidance_scales = [1, 1.5, 2]\n", | |
"steps = [20, 25, 40, 50]\n", | |
"\n", | |
"hash_image = run_bulk_experiments(\n", | |
" image=mountain,\n", | |
" edit_prompt=prompt,\n", | |
" guidance_scales=guidance_scales,\n", | |
" image_guidance_scales=image_guidance_scales,\n", | |
" steps=steps,\n", | |
")\n", | |
"\n", | |
"# Log the mountain results in a separate W&B run.\n", | |
"log_to_wandb(\n", | |
" initial_image_path=\"mountain.png\", edit_prompt=prompt, image_hex=hash_image\n", | |
")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment