Forked from inu-ai/dreambooth_stable_diffusion_mod.ipynb
Created
February 1, 2023 23:22
-
-
Save jochemstoel/f15419f36a7199f38a368268d1b7f587 to your computer and use it in GitHub Desktop.
dreambooth_stable_diffusion_mod.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/thx-pw/37ba770dfd2b15f119ec8032c3ae90a1/dreambooth_stable_diffusion_mod.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "form", | |
"id": "XU7NuMAA2drw" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@markdown Check type of GPU and VRAM available.\n", | |
"# Prints one CSV line (no header) per visible GPU: name, total memory, free memory.\n", | |
"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "BzM7j0ZSc_9c" | |
}, | |
"source": [ | |
"https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "wnTMyW41cC1E" | |
}, | |
"source": [ | |
"## Install Requirements" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "aLWXPZqjsZVV" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Fetch the training and checkpoint-conversion scripts. `-O` overwrites the copy from a\n", | |
"# previous run; plain `wget` would save `train_dreambooth.py.1` instead and the notebook\n", | |
"# would keep executing the stale file.\n", | |
"!wget -q -O train_dreambooth.py https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py\n", | |
"!wget -q -O convert_diffusers_to_original_stable_diffusion.py https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n", | |
"# NOTE(review): only bitsandbytes is pinned; diffusers (git main) and the other packages float.\n", | |
"%pip install -qq git+https://github.com/ShivamShrirao/diffusers\n", | |
"%pip install -q -U --pre triton\n", | |
"%pip install -q accelerate transformers ftfy bitsandbytes==0.35.0 gradio natsort" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "form", | |
"id": "y4lqqWT_uxD2" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@title Login to HuggingFace 🤗\n", | |
"\n", | |
"#@markdown You need to accept the model license before downloading or using the Stable Diffusion weights. Please, visit the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work.\n", | |
"# https://huggingface.co/settings/tokens\n", | |
"import os\n", | |
"from pathlib import Path\n", | |
"\n", | |
"HUGGINGFACE_TOKEN = \"\" #@param {type:\"string\"}\n", | |
"\n", | |
"# Write the token from Python rather than via a `!echo` shell command, so the\n", | |
"# credential is never interpolated into a shell command line.\n", | |
"token_path = Path.home() / \".huggingface\" / \"token\"\n", | |
"token_path.parent.mkdir(parents=True, exist_ok=True)\n", | |
"token = HUGGINGFACE_TOKEN.strip()\n", | |
"if token:\n", | |
"    token_path.write_text(token)  # no trailing newline, like `echo -n`\n", | |
"    os.chmod(token_path, 0o600)  # credential: readable/writable by the owner only\n", | |
"    print(f\"[*] Hugging Face token written to {token_path}\")\n", | |
"else:\n", | |
"    # Nothing is written (a token file from an earlier run is left in place); downloads\n", | |
"    # of license-gated models will likely fail without a token.\n", | |
"    print(\"[!] HUGGINGFACE_TOKEN is empty; no token written.\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XfTlc8Mqb8iH" | |
}, | |
"source": [ | |
"### Install xformers from precompiled wheel." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "n6dcjPnnaiCn" | |
}, | |
"outputs": [], | |
"source": [ | |
"# --no-deps: pip installs only this wheel, not the dependencies it declares.\n", | |
"# NOTE(review): the wheel is tagged cp38-cp38-linux_x86_64 (CPython 3.8 only); pip should\n", | |
"# refuse it on a runtime with another Python version -- check `!python --version` and use\n", | |
"# the source build below in that case.\n", | |
"%pip install --no-deps -q https://github.com/brian6091/xformers-wheels/releases/download/0.0.15.dev0%2B4c06c79/xformers-0.0.15.dev0+4c06c79.d20221205-cp38-cp38-linux_x86_64.whl\n", | |
"# These were compiled on Tesla T4.\n", | |
"\n", | |
"# If precompiled wheels don't work, install it with the following command. It will take around 40 minutes to compile.\n", | |
"# %pip install git+https://github.com/facebookresearch/xformers@4c06c79#egg=xformers" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "G0NV324ZcL9L" | |
}, | |
"source": [ | |
"## Settings and run" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "Rxg0y5MBudmd", | |
"cellView": "form" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@markdown Name/Path of the initial model.\n", | |
"import os\n", | |
"\n", | |
"MODEL_NAME = \"hakurei/waifu-diffusion\" #@param [\"hakurei/waifu-diffusion\", \"runwayml/stable-diffusion-v1-5\", \"CompVis/stable-diffusion-v1-4\", \"naclbit/trinart_stable_diffusion_v2,diffusers-115k\", \"naclbit/trinart_stable_diffusion_v2,diffusers-95k\", \"naclbit/trinart_stable_diffusion_v2,diffusers-60k\"] {allow-input: true}\n", | |
"\n", | |
"INSTANCE_NAME = \"sks\" #@param {type:\"string\"}\n", | |
"CLASS_NAME = \"1boy\" #@param {type:\"string\"}\n", | |
"INSTANCE_PROMPT = f\"{INSTANCE_NAME} {CLASS_NAME}\"\n", | |
"\n", | |
"# Class (prior-preservation) images; this directory is not created here.\n", | |
"CLASS_DIR = f\"/content/data/{CLASS_NAME}\"\n", | |
"\n", | |
"# Instance (subject) images go here; the upload cell moves files into it.\n", | |
"INSTANCE_DIR = f\"/content/data/{INSTANCE_NAME}\"\n", | |
"# os.makedirs instead of `!mkdir -p $INSTANCE_DIR`: the unquoted shell form splits a\n", | |
"# name containing spaces into several separate directories.\n", | |
"os.makedirs(INSTANCE_DIR, exist_ok=True)\n", | |
"\n", | |
"OUTPUT_DIR = f\"/content/stable_diffusion_weights/{INSTANCE_NAME}\"\n", | |
"print(f\"[*] Weights will be saved at {OUTPUT_DIR}\")\n", | |
"os.makedirs(OUTPUT_DIR, exist_ok=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "form", | |
"id": "fe-GgtnUVO_e" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@markdown Upload your images by running this cell.\n", | |
"\n", | |
"#@markdown OR\n", | |
"\n", | |
"#@markdown You can use the file manager on the left panel to upload (drag and drop) to INSTANCE_DIR (it uploads faster)\n", | |
"\n", | |
"import os\n", | |
"import shutil\n", | |
"\n", | |
"from google.colab import files\n", | |
"\n", | |
"# files.upload() returns a dict keyed by the uploaded file names; the files land in the\n", | |
"# current working directory, so each one is moved by its bare name into INSTANCE_DIR.\n", | |
"uploaded = files.upload()\n", | |
"for filename in uploaded:\n", | |
"    shutil.move(filename, os.path.join(INSTANCE_DIR, filename))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "qn5ILIyDJIcX" | |
}, | |
"source": [ | |
"# Start Training\n", | |
"\n", | |
"Use the table below to choose the best flags based on your memory and speed requirements. Tested on Tesla T4 GPU.\n", | |
"\n", | |
"\n", | |
"| `fp16` | `train_batch_size` | `gradient_accumulation_steps` | `gradient_checkpointing` | `use_8bit_adam` | GB VRAM usage | Speed (it/s) |\n", | |
"| ---- | ------------------ | ----------------------------- | ----------------------- | --------------- | ---------- | ------------ |\n", | |
"| fp16 | 1 | 1 | TRUE | TRUE | 9.92 | 0.93 |\n", | |
"| no | 1 | 1 | TRUE | TRUE | 10.08 | 0.42 |\n", | |
"| fp16 | 2 | 1 | TRUE | TRUE | 10.4 | 0.66 |\n", | |
"| fp16 | 1 | 1 | FALSE | TRUE | 11.17 | 1.14 |\n", | |
"| no | 1 | 1 | FALSE | TRUE | 11.17 | 0.49 |\n", | |
"| fp16 | 1 | 2 | TRUE | TRUE | 11.56 | 1 |\n", | |
"| fp16 | 2 | 1 | FALSE | TRUE | 13.67 | 0.82 |\n", | |
"| fp16 | 1 | 2 | FALSE | TRUE | 13.7 | 0.83 |\n", | |
"| fp16 | 1 | 1 | TRUE | FALSE | 15.79 | 0.77 |\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "-ioxxvHoicPs" | |
}, | |
"source": [ | |
"Add the `--gradient_checkpointing` flag to bring VRAM usage down to around 9.92 GB.\n", | |
"\n", | |
"Remove the `--use_8bit_adam` flag to use a full-precision Adam optimizer. This requires 15.79 GB with `--gradient_checkpointing`, or 17.8 GB without it." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "jjcSXTp-u-Eg" | |
}, | |
"outputs": [], | |
"source": [ | |
"# MODEL_NAME is either a repo id / local path, or \"repo_id,revision\" (the\n", | |
"# naclbit/trinart_stable_diffusion_v2 choices in the settings form). The comma form is\n", | |
"# not a usable repo id, so split it into the repo and the branch to load; plain names\n", | |
"# keep the \"fp16\" revision this cell always used.\n", | |
"MODEL_ID, _, MODEL_REVISION = MODEL_NAME.partition(\",\")\n", | |
"MODEL_ID = MODEL_ID.strip()\n", | |
"MODEL_REVISION = MODEL_REVISION.strip() or \"fp16\"\n", | |
"print(f\"[*] Training from {MODEL_ID} (revision {MODEL_REVISION})\")\n", | |
"\n", | |
"# Paths and prompts are quoted so values containing spaces reach the script intact.\n", | |
"!accelerate launch train_dreambooth.py \\\n", | |
"  --pretrained_model_name_or_path=\"{MODEL_ID}\" \\\n", | |
"  --pretrained_vae_name_or_path=\"\" \\\n", | |
"  --instance_data_dir=\"{INSTANCE_DIR}\" \\\n", | |
"  --class_data_dir=\"{CLASS_DIR}\" \\\n", | |
"  --output_dir=\"{OUTPUT_DIR}\" \\\n", | |
"  --revision=\"{MODEL_REVISION}\" \\\n", | |
"  --with_prior_preservation --prior_loss_weight=1.0 \\\n", | |
"  --resolution=512 \\\n", | |
"  --train_batch_size=4 \\\n", | |
"  --train_text_encoder \\\n", | |
"  --pad_tokens \\\n", | |
"  --mixed_precision=\"fp16\" \\\n", | |
"  --gradient_checkpointing \\\n", | |
"  --use_8bit_adam \\\n", | |
"  --gradient_accumulation_steps=1 \\\n", | |
"  --learning_rate=1e-6 \\\n", | |
"  --lr_scheduler=\"constant\" \\\n", | |
"  --lr_warmup_steps=0 \\\n", | |
"  --num_class_images=60 \\\n", | |
"  --sample_batch_size=1 \\\n", | |
"  --max_train_steps=300 \\\n", | |
"  --instance_prompt=\"{INSTANCE_PROMPT}\" \\\n", | |
"  --class_prompt=\"{CLASS_NAME}\"" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "5V8wgU0HN-Kq" | |
}, | |
"source": [ | |
"## Convert weights to ckpt to use in web UIs like AUTOMATIC1111." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "form", | |
"id": "89Az5NUxOWdy" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@markdown Run conversion.\n", | |
"import os\n", | |
"from glob import glob\n", | |
"from natsort import natsorted\n", | |
"\n", | |
"# Saved weights are sub-directories of OUTPUT_DIR; natural sort orders \"1000\" after\n", | |
"# \"300\", so the last directory is taken as the most recent one.\n", | |
"weight_dirs = natsorted([d for d in glob(os.path.join(OUTPUT_DIR, \"*\")) if os.path.isdir(d)])\n", | |
"if not weight_dirs:\n", | |
"    raise FileNotFoundError(f\"No trained weights found under {OUTPUT_DIR}; run the training cell first.\")\n", | |
"WEIGHTS_DIR = weight_dirs[-1]\n", | |
"print(f\"[*] WEIGHTS_DIR={WEIGHTS_DIR}\")\n", | |
"\n", | |
"ckpt_path = os.path.join(WEIGHTS_DIR, \"model.ckpt\")\n", | |
"\n", | |
"half_arg = \"\"\n", | |
"#@markdown Whether to convert to fp16, takes half the space (2GB).\n", | |
"fp16 = True #@param {type: \"boolean\"}\n", | |
"if fp16:\n", | |
"    half_arg = \"--half\"\n", | |
"!python convert_diffusers_to_original_stable_diffusion.py --model_path \"{WEIGHTS_DIR}\" --checkpoint_path \"{ckpt_path}\" $half_arg\n", | |
"# `!` commands do not raise on failure, so check the output exists before reporting success.\n", | |
"if not os.path.isfile(ckpt_path):\n", | |
"    raise RuntimeError(f\"Conversion did not produce {ckpt_path}; see the converter output above.\")\n", | |
"print(f\"[*] Converted ckpt saved at {ckpt_path}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title Save ckpt to google drive\n", | |
"import os\n", | |
"import shutil\n", | |
"from google.colab import drive\n", | |
"drive.mount('/content/drive')\n", | |
"\n", | |
"model_checkpoints = \"/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion\"\n", | |
"os.makedirs(model_checkpoints, exist_ok=True)\n", | |
"\n", | |
"# The conversion always writes a file named \"model.ckpt\"; copying it under that name\n", | |
"# would overwrite the previous export in Drive. Name the copy after the instance token\n", | |
"# and the checkpoint directory it came from instead.\n", | |
"drive_ckpt_path = os.path.join(model_checkpoints, f\"{INSTANCE_NAME}_{os.path.basename(WEIGHTS_DIR)}.ckpt\")\n", | |
"shutil.copyfile(ckpt_path, drive_ckpt_path)  # content only, like plain `cp`\n", | |
"print(f\"[*] Copied {ckpt_path} to {drive_ckpt_path}\")" | |
], | |
"metadata": { | |
"cellView": "form", | |
"id": "MYDjfXf8MB2R" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ToNG4fd_dTbF" | |
}, | |
"source": [ | |
"## Inference" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "gW15FjffdTID" | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from torch import autocast\n", | |
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n", | |
"from IPython.display import display\n", | |
"\n", | |
"# To reuse a model trained earlier and saved in gdrive, set this to that model's full gdrive path.\n", | |
"model_path = WEIGHTS_DIR\n", | |
"\n", | |
"# DDIM sampler with a scaled-linear beta schedule from 0.00085 to 0.012.\n", | |
"scheduler = DDIMScheduler(\n", | |
"    beta_start=0.00085,\n", | |
"    beta_end=0.012,\n", | |
"    beta_schedule=\"scaled_linear\",\n", | |
"    clip_sample=False,\n", | |
"    set_alpha_to_one=False,\n", | |
")\n", | |
"\n", | |
"# Half-precision pipeline on the GPU, loaded without a safety checker.\n", | |
"pipe = StableDiffusionPipeline.from_pretrained(\n", | |
"    model_path,\n", | |
"    scheduler=scheduler,\n", | |
"    safety_checker=None,\n", | |
"    torch_dtype=torch.float16,\n", | |
").to(\"cuda\")\n", | |
"\n", | |
"# No generator until the seed cell runs, so sampling starts unseeded.\n", | |
"g_cuda = None" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "form", | |
"id": "oIzkltjpVO_f" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@markdown Can set random seed here for reproducibility.\n", | |
"seed = 52362 #@param {type:\"number\"}\n", | |
"# A dedicated CUDA generator that the generation cells pass to the pipeline.\n", | |
"g_cuda = torch.Generator(device='cuda')\n", | |
"g_cuda.manual_seed(seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "K6xoHWSsbcS3", | |
"cellView": "form" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@title Run for generating images.\n", | |
"\n", | |
"prompt = \"sks 1boy in a bucket\" #@param {type:\"string\"}\n", | |
"negative_prompt = \"\" #@param {type:\"string\"}\n", | |
"num_samples = 4 #@param {type:\"number\"}\n", | |
"guidance_scale = 7.5 #@param {type:\"number\"}\n", | |
"num_inference_steps = 50 #@param {type:\"number\"}\n", | |
"height = 512 #@param {type:\"number\"}\n", | |
"width = 512 #@param {type:\"number\"}\n", | |
"\n", | |
"# Colab \"number\" fields can hold floats (e.g. 512.0); cast counts and sizes to int the\n", | |
"# same way the Gradio inference() helper does.\n", | |
"with autocast(\"cuda\"), torch.inference_mode():\n", | |
"    images = pipe(\n", | |
"        prompt,\n", | |
"        height=int(height),\n", | |
"        width=int(width),\n", | |
"        negative_prompt=negative_prompt,\n", | |
"        num_images_per_prompt=int(num_samples),\n", | |
"        num_inference_steps=int(num_inference_steps),\n", | |
"        guidance_scale=float(guidance_scale),\n", | |
"        generator=g_cuda,\n", | |
"    ).images\n", | |
"\n", | |
"# Render each generated image inline in the cell output.\n", | |
"for img in images:\n", | |
"    display(img)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "form", | |
"id": "WMCqQ5Tcdsm2" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@markdown Run Gradio UI for generating images.\n", | |
"import gradio as gr\n", | |
"\n", | |
"def inference(prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):\n", | |
"    \"\"\"Generate images with `pipe`; UI number values may arrive as floats, hence the int() casts.\"\"\"\n", | |
"    with torch.autocast(\"cuda\"), torch.inference_mode():\n", | |
"        return pipe(\n", | |
"            prompt, height=int(height), width=int(width),\n", | |
"            negative_prompt=negative_prompt,\n", | |
"            num_images_per_prompt=int(num_samples),\n", | |
"            num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,\n", | |
"            generator=g_cuda\n", | |
"        ).images\n", | |
"\n", | |
"# Component handles use *_box / *_input names so they do not overwrite the prompt,\n", | |
"# height, width, ... values set by the image-generation cell.\n", | |
"with gr.Blocks() as demo:\n", | |
"    with gr.Row():\n", | |
"        with gr.Column():\n", | |
"            # Default to this notebook's instance prompt rather than a hard-coded \"sks dog\".\n", | |
"            prompt_box = gr.Textbox(label=\"Prompt\", value=f\"{INSTANCE_PROMPT} in a bucket\")\n", | |
"            negative_prompt_box = gr.Textbox(label=\"Negative Prompt\", value=\"\")\n", | |
"            run = gr.Button(value=\"Generate\")\n", | |
"            with gr.Row():\n", | |
"                num_samples_input = gr.Number(label=\"Number of Samples\", value=4)\n", | |
"                guidance_scale_input = gr.Number(label=\"Guidance Scale\", value=7.5)\n", | |
"            with gr.Row():\n", | |
"                height_input = gr.Number(label=\"Height\", value=512)\n", | |
"                width_input = gr.Number(label=\"Width\", value=512)\n", | |
"            # At least one denoising step is needed; step=1 keeps the value integral.\n", | |
"            num_inference_steps_input = gr.Slider(label=\"Steps\", minimum=1, maximum=100, step=1, value=50)\n", | |
"        with gr.Column():\n", | |
"            gallery = gr.Gallery()\n", | |
"\n", | |
"    run.click(\n", | |
"        inference,\n", | |
"        inputs=[prompt_box, negative_prompt_box, num_samples_input, height_input, width_input, num_inference_steps_input, guidance_scale_input],\n", | |
"        outputs=gallery,\n", | |
"    )\n", | |
"\n", | |
"demo.launch(debug=True)" | |
] | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"collapsed_sections": [ | |
"XfTlc8Mqb8iH" | |
], | |
"provenance": [], | |
"private_outputs": true, | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3.8.12 ('pytorch')", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.8.12" | |
}, | |
"vscode": { | |
"interpreter": { | |
"hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment