dreambooth_stable_diffusion_mod.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/thx-pw/37ba770dfd2b15f119ec8032c3ae90a1/dreambooth_stable_diffusion_mod.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XU7NuMAA2drw"
},
"outputs": [],
"source": [
"#@markdown Check type of GPU and VRAM available.\n",
"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
]
},
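{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@markdown (Optional) A minimal sanity-check sketch, assuming Colab's preinstalled PyTorch: the flag table further down was measured on a 16 GB Tesla T4, so warn if less VRAM is available.\n",
"import torch\n",
"\n",
"if not torch.cuda.is_available():\n",
"    raise RuntimeError(\"No CUDA GPU detected; set the Colab runtime type to GPU.\")\n",
"props = torch.cuda.get_device_properties(0)\n",
"total_gb = props.total_memory / 1024**3\n",
"print(f\"GPU: {props.name}, total VRAM: {total_gb:.1f} GB\")\n",
"if total_gb < 15:\n",
"    print(\"Warning: less than ~15 GB of VRAM; keep gradient checkpointing and 8-bit Adam enabled.\")"
]
},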
{
"cell_type": "markdown",
"metadata": {
"id": "BzM7j0ZSc_9c"
},
"source": [
"https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wnTMyW41cC1E"
},
"source": [
"## Install Requirements"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aLWXPZqjsZVV"
},
"outputs": [],
"source": [
"!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py\n",
"!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n",
"%pip install -qq git+https://github.com/ShivamShrirao/diffusers\n",
"%pip install -q -U --pre triton\n",
"%pip install -q accelerate transformers ftfy bitsandbytes==0.35.0 gradio natsort"
]
},
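{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@markdown (Optional) Quick sanity check that the installs above succeeded; a minimal sketch, not required for training.\n",
"import accelerate\n",
"import diffusers\n",
"import transformers\n",
"\n",
"print(\"diffusers:\", diffusers.__version__)\n",
"print(\"transformers:\", transformers.__version__)\n",
"print(\"accelerate:\", accelerate.__version__)"
]
},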
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "y4lqqWT_uxD2"
},
"outputs": [],
"source": [
"#@title Login to HuggingFace 🤗\n",
"\n",
"#@markdown You need to accept the model license before downloading or using the Stable Diffusion weights. Please visit the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license, and tick the checkbox if you agree. You must be a registered user on the 🤗 Hugging Face Hub, and you'll also need an access token for the code to work.\n",
"# https://huggingface.co/settings/tokens\n",
"!mkdir -p ~/.huggingface\n",
"HUGGINGFACE_TOKEN = \"\" #@param {type:\"string\"}\n",
"!echo -n \"{HUGGINGFACE_TOKEN}\" > ~/.huggingface/token"
]
},
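{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@markdown (Optional) Verify the token before downloading weights; a minimal sketch using `huggingface_hub`, which the installs above pull in as a dependency.\n",
"from huggingface_hub import HfApi\n",
"\n",
"user = HfApi().whoami(token=HUGGINGFACE_TOKEN)\n",
"print(\"Logged in to Hugging Face as:\", user[\"name\"])"
]
},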
{
"cell_type": "markdown",
"metadata": {
"id": "XfTlc8Mqb8iH"
},
"source": [
"### Install xformers from precompiled wheel."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "n6dcjPnnaiCn"
},
"outputs": [],
"source": [
"%pip install --no-deps -q https://github.com/brian6091/xformers-wheels/releases/download/0.0.15.dev0%2B4c06c79/xformers-0.0.15.dev0+4c06c79.d20221205-cp38-cp38-linux_x86_64.whl\n",
"# This wheel was compiled on a Tesla T4.\n",
"\n",
"# If the precompiled wheel doesn't work, install xformers with the following command instead. It takes around 40 minutes to compile.\n",
"# %pip install git+https://github.com/facebookresearch/xformers@4c06c79#egg=xformers"
]
},
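{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@markdown (Optional) Confirm the xformers wheel imports and memory-efficient attention runs on this GPU; a minimal sketch with an arbitrary tiny tensor.\n",
"import torch\n",
"import xformers\n",
"import xformers.ops\n",
"\n",
"print(\"xformers:\", xformers.__version__)\n",
"q = torch.rand(1, 16, 64, device=\"cuda\", dtype=torch.float16)\n",
"out = xformers.ops.memory_efficient_attention(q, q, q)\n",
"print(\"memory_efficient_attention OK, output shape:\", tuple(out.shape))"
]
},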
{
"cell_type": "markdown",
"metadata": {
"id": "G0NV324ZcL9L"
},
"source": [
"## Settings and run"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rxg0y5MBudmd",
"cellView": "form"
},
"outputs": [],
"source": [
"#@markdown Name/Path of the initial model.\n",
"MODEL_NAME = \"hakurei/waifu-diffusion\" #@param [\"hakurei/waifu-diffusion\", \"runwayml/stable-diffusion-v1-5\", \"CompVis/stable-diffusion-v1-4\", \"naclbit/trinart_stable_diffusion_v2,diffusers-115k\", \"naclbit/trinart_stable_diffusion_v2,diffusers-95k\", \"naclbit/trinart_stable_diffusion_v2,diffusers-60k\"] {allow-input: true}\n",
"\n",
"INSTANCE_NAME = \"sks\" #@param {type:\"string\"}\n",
"CLASS_NAME = \"1boy\" #@param {type:\"string\"}\n",
"INSTANCE_PROMPT = f\"{INSTANCE_NAME} {CLASS_NAME}\"\n",
"\n",
"CLASS_DIR = f\"/content/data/{CLASS_NAME}\"\n",
"\n",
"INSTANCE_DIR = f\"/content/data/{INSTANCE_NAME}\"\n",
"!mkdir -p $INSTANCE_DIR\n",
"\n",
"OUTPUT_DIR = f\"/content/stable_diffusion_weights/{INSTANCE_NAME}\"\n",
"print(f\"[*] Weights will be saved at {OUTPUT_DIR}\")\n",
"!mkdir -p $OUTPUT_DIR"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "fe-GgtnUVO_e"
},
"outputs": [],
"source": [
"#@markdown Upload your images by running this cell.\n",
"\n",
"#@markdown OR\n",
"\n",
"#@markdown You can use the file manager on the left panel to upload (drag and drop) to INSTANCE_DIR (it uploads faster)\n",
"\n",
"import os\n",
"from google.colab import files\n",
"import shutil\n",
"\n",
"uploaded = files.upload()\n",
"for filename in uploaded.keys():\n",
" dst_path = os.path.join(INSTANCE_DIR, filename)\n",
" shutil.move(filename, dst_path)"
]
},
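{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@markdown (Optional) Quick look at the instance images before training; a minimal sketch assuming Pillow, which ships with Colab.\n",
"import os\n",
"from PIL import Image\n",
"\n",
"image_files = [f for f in sorted(os.listdir(INSTANCE_DIR)) if f.lower().endswith((\".png\", \".jpg\", \".jpeg\", \".webp\"))]\n",
"print(f\"{len(image_files)} instance image(s) in {INSTANCE_DIR}\")\n",
"for name in image_files:\n",
"    with Image.open(os.path.join(INSTANCE_DIR, name)) as im:\n",
"        print(f\"  {name}: {im.size[0]}x{im.size[1]}\")"
]
},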
{
"cell_type": "markdown",
"metadata": {
"id": "qn5ILIyDJIcX"
},
"source": [
"# Start Training\n",
"\n",
"Use the table below to choose the best flags based on your memory and speed requirements. Tested on a Tesla T4 GPU.\n",
"\n",
"\n",
"| `mixed_precision` | `train_batch_size` | `gradient_accumulation_steps` | `gradient_checkpointing` | `use_8bit_adam` | VRAM usage (GB) | Speed (it/s) |\n",
"| ---- | ------------------ | ----------------------------- | ----------------------- | --------------- | ---------- | ------------ |\n",
"| fp16 | 1 | 1 | TRUE | TRUE | 9.92 | 0.93 |\n",
"| no | 1 | 1 | TRUE | TRUE | 10.08 | 0.42 |\n",
"| fp16 | 2 | 1 | TRUE | TRUE | 10.4 | 0.66 |\n",
"| fp16 | 1 | 1 | FALSE | TRUE | 11.17 | 1.14 |\n",
"| no | 1 | 1 | FALSE | TRUE | 11.17 | 0.49 |\n",
"| fp16 | 1 | 2 | TRUE | TRUE | 11.56 | 1 |\n",
"| fp16 | 2 | 1 | FALSE | TRUE | 13.67 | 0.82 |\n",
"| fp16 | 1 | 2 | FALSE | TRUE | 13.7 | 0.83 |\n",
"| fp16 | 1 | 1 | TRUE | FALSE | 15.79 | 0.77 |\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-ioxxvHoicPs"
},
"source": [
"Add the `--gradient_checkpointing` flag for around 9.92 GB of VRAM usage.\n",
"\n",
"Remove the `--use_8bit_adam` flag to use the full-precision optimizer. That requires 15.79 GB with `--gradient_checkpointing`, or 17.8 GB without it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jjcSXTp-u-Eg"
},
"outputs": [],
"source": [
"!accelerate launch train_dreambooth.py \\\n",
" --pretrained_model_name_or_path=$MODEL_NAME \\\n",
" --pretrained_vae_name_or_path=\"\" \\\n",
" --instance_data_dir=$INSTANCE_DIR \\\n",
" --class_data_dir=$CLASS_DIR \\\n",
" --output_dir=$OUTPUT_DIR \\\n",
" --revision=\"fp16\" \\\n",
" --with_prior_preservation --prior_loss_weight=1.0 \\\n",
" --resolution=512 \\\n",
" --train_batch_size=4 \\\n",
" --train_text_encoder \\\n",
" --pad_tokens \\\n",
" --mixed_precision=\"fp16\" \\\n",
" --gradient_checkpointing \\\n",
" --use_8bit_adam \\\n",
" --gradient_accumulation_steps=1 \\\n",
" --learning_rate=1e-6 \\\n",
" --lr_scheduler=\"constant\" \\\n",
" --lr_warmup_steps=0 \\\n",
" --num_class_images=60 \\\n",
" --sample_batch_size=1 \\\n",
" --max_train_steps=300 \\\n",
" --instance_prompt=\"{INSTANCE_PROMPT}\" \\\n",
" --class_prompt=\"{CLASS_NAME}\""
]
},
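{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@markdown (Optional) List the checkpoint folders training wrote to OUTPUT_DIR; the conversion cell below automatically picks the most recent one.\n",
"import os\n",
"from glob import glob\n",
"from natsort import natsorted\n",
"\n",
"for d in natsorted(glob(OUTPUT_DIR + os.sep + \"*\")):\n",
"    print(d)"
]
},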
{
"cell_type": "markdown",
"metadata": {
"id": "5V8wgU0HN-Kq"
},
"source": [
"## Convert the weights to a ckpt file for use in web UIs like AUTOMATIC1111."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "89Az5NUxOWdy"
},
"outputs": [],
"source": [
"#@markdown Run conversion.\n",
"from natsort import natsorted\n",
"from glob import glob\n",
"import os\n",
"WEIGHTS_DIR = natsorted(glob(OUTPUT_DIR + os.sep + \"*\"))[-1]\n",
"print(f\"[*] WEIGHTS_DIR={WEIGHTS_DIR}\")\n",
"\n",
"ckpt_path = WEIGHTS_DIR + \"/model.ckpt\"\n",
"\n",
"half_arg = \"\"\n",
"#@markdown Whether to convert to fp16, takes half the space (2GB).\n",
"fp16 = True #@param {type: \"boolean\"}\n",
"if fp16:\n",
" half_arg = \"--half\"\n",
"!python convert_diffusers_to_original_stable_diffusion.py --model_path $WEIGHTS_DIR --checkpoint_path $ckpt_path $half_arg\n",
"print(f\"[*] Converted ckpt saved at {ckpt_path}\")"
]
},
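{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@markdown (Optional) Confirm the converted checkpoint exists and report its size before copying it anywhere.\n",
"import os\n",
"\n",
"size_gb = os.path.getsize(ckpt_path) / 1024**3\n",
"print(f\"{ckpt_path}: {size_gb:.2f} GB\")"
]
},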
{
"cell_type": "code",
"source": [
"#@title Save ckpt to google drive\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"import os\n",
"model_checkpoints = \"/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion\"\n",
"os.makedirs(model_checkpoints, exist_ok=True)\n",
"!cp {ckpt_path} {model_checkpoints}"
],
"metadata": {
"cellView": "form",
"id": "MYDjfXf8MB2R"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ToNG4fd_dTbF"
},
"source": [
"## Inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gW15FjffdTID"
},
"outputs": [],
"source": [
"import torch\n",
"from torch import autocast\n",
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n",
"from IPython.display import display\n",
"\n",
"model_path = WEIGHTS_DIR # To use a previously trained model saved in Google Drive, replace this with the full path to that model's weights folder in Drive\n",
"\n",
"scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
"pipe = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16).to(\"cuda\")\n",
"\n",
"g_cuda = None"
]
},
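{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@markdown (Optional) Try to enable xformers memory-efficient attention for lower VRAM use at inference; available in recent diffusers releases and safe to skip if it raises.\n",
"try:\n",
"    pipe.enable_xformers_memory_efficient_attention()\n",
"    print(\"xformers attention enabled\")\n",
"except Exception as e:\n",
"    print(\"Could not enable xformers attention:\", e)"
]
},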
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "oIzkltjpVO_f"
},
"outputs": [],
"source": [
"#@markdown Set a random seed here for reproducibility.\n",
"g_cuda = torch.Generator(device='cuda')\n",
"seed = 52362 #@param {type:\"number\"}\n",
"g_cuda.manual_seed(seed)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K6xoHWSsbcS3",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title Run for generating images.\n",
"\n",
"prompt = \"sks 1boy in a bucket\" #@param {type:\"string\"}\n",
"negative_prompt = \"\" #@param {type:\"string\"}\n",
"num_samples = 4 #@param {type:\"number\"}\n",
"guidance_scale = 7.5 #@param {type:\"number\"}\n",
"num_inference_steps = 50 #@param {type:\"number\"}\n",
"height = 512 #@param {type:\"number\"}\n",
"width = 512 #@param {type:\"number\"}\n",
"\n",
"with autocast(\"cuda\"), torch.inference_mode():\n",
" images = pipe(\n",
" prompt,\n",
" height=height,\n",
" width=width,\n",
" negative_prompt=negative_prompt,\n",
" num_images_per_prompt=num_samples,\n",
" num_inference_steps=num_inference_steps,\n",
" guidance_scale=guidance_scale,\n",
" generator=g_cuda\n",
" ).images\n",
"\n",
"for img in images:\n",
" display(img)"
]
},
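{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@markdown (Optional) Save the images from the previous cell to disk; a minimal sketch, the output folder name is arbitrary.\n",
"import os\n",
"\n",
"save_dir = \"/content/outputs\"\n",
"os.makedirs(save_dir, exist_ok=True)\n",
"for i, img in enumerate(images):\n",
"    img.save(os.path.join(save_dir, f\"sample_{i:02d}.png\"))\n",
"print(f\"Saved {len(images)} image(s) to {save_dir}\")"
]
},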
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "WMCqQ5Tcdsm2"
},
"outputs": [],
"source": [
"#@markdown Run Gradio UI for generating images.\n",
"import gradio as gr\n",
"\n",
"def inference(prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):\n",
" with torch.autocast(\"cuda\"), torch.inference_mode():\n",
" return pipe(\n",
" prompt, height=int(height), width=int(width),\n",
" negative_prompt=negative_prompt,\n",
" num_images_per_prompt=int(num_samples),\n",
" num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,\n",
" generator=g_cuda\n",
" ).images\n",
"\n",
"with gr.Blocks() as demo:\n",
" with gr.Row():\n",
" with gr.Column():\n",
" prompt = gr.Textbox(label=\"Prompt\", value=\"photo of sks dog in a bucket\")\n",
" negative_prompt = gr.Textbox(label=\"Negative Prompt\", value=\"\")\n",
" run = gr.Button(value=\"Generate\")\n",
" with gr.Row():\n",
" num_samples = gr.Number(label=\"Number of Samples\", value=4)\n",
" guidance_scale = gr.Number(label=\"Guidance Scale\", value=7.5)\n",
" with gr.Row():\n",
" height = gr.Number(label=\"Height\", value=512)\n",
" width = gr.Number(label=\"Width\", value=512)\n",
" num_inference_steps = gr.Slider(label=\"Steps\", value=50)\n",
" with gr.Column():\n",
" gallery = gr.Gallery()\n",
"\n",
" run.click(inference, inputs=[prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale], outputs=gallery)\n",
"\n",
"demo.launch(debug=True)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"XfTlc8Mqb8iH"
],
"provenance": [],
"private_outputs": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3.8.12 ('pytorch')",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.12"
},
"vscode": {
"interpreter": {
"hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}