Created February 19, 2024 21:30
Save CoffeeVampir3/610e4627042ac8f36b45da6ec3af776f to your computer and use it in GitHub Desktop.
vae-preview for SDXL
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "d62240a6-6331-41fa-8b89-bc16e72a4425", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import lovely_tensors as lt\n", | |
"lt.monkey_patch()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "7c693e12-6922-4e7e-b610-4e343a05c108", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"2023-08-21 13:57:23.739433: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", | |
"2023-08-21 13:57:23.790664: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", | |
"2023-08-21 13:57:24.044699: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", | |
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline\n", | |
"from PIL import Image\n", | |
"import os, gc, random, sys, json, random, time\n", | |
"\n", | |
class NoWatermarker:
    """Pass-through stand-in for the SDXL pipeline's watermarker.

    Exposes the same ``apply_watermark`` hook the pipeline calls, but hands the
    images back untouched so no watermark is embedded.
    """

    def apply_watermark(self, images: torch.FloatTensor):
        # Identity: the exact object passed in is returned unchanged.
        return images
"\n", | |
# Cache of the most recently loaded pipeline and the checkpoint it came from.
LOADED_PIPE = None
LOADED_MODEL_PATH = None


def _apply_scheduler(pipe, scheduler):
    """Install ``scheduler`` on ``pipe``; ``None`` keeps the pipe's current scheduler."""
    if scheduler is not None and pipe.scheduler is not scheduler:
        pipe.scheduler = scheduler


def load_diffusers_pipe(model_path, scheduler, device):
    """Load (or reuse) an SDXL text-to-image pipeline from a single checkpoint file.

    The last successfully loaded pipeline is cached in the module globals
    ``LOADED_PIPE`` / ``LOADED_MODEL_PATH``; asking again for the same path
    returns the cached pipeline instead of reloading it.

    Args:
        model_path: path to a single-file SDXL checkpoint.
        scheduler: scheduler instance to install on the pipeline, or ``None``
            to keep the scheduler the pipeline already has.
        device: device the freshly loaded pipeline is moved to (e.g. "cuda").

    Returns:
        The pipeline for ``model_path``, or ``None`` if loading failed. The
        error is printed and any previously cached pipeline is left as-is.
    """
    global LOADED_PIPE, LOADED_MODEL_PATH
    torch.set_float32_matmul_precision('high')

    # Reuse the cached pipe when it was built from the same checkpoint.
    # (The original called an undefined `load_scheduler` here, which raised
    # NameError whenever the cell was re-run with scheduler=None.)
    if LOADED_PIPE is not None and LOADED_MODEL_PATH == model_path:
        _apply_scheduler(LOADED_PIPE, scheduler)
        return LOADED_PIPE

    try:
        pipe = StableDiffusionXLPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
            add_watermarker=False)
        # Pass-through watermarker so nothing is stamped into the output images.
        pipe.watermark = NoWatermarker()
        pipe.enable_vae_tiling()
        _apply_scheduler(pipe, scheduler)
        pipe.to(device)
        #if device != "cpu":
        #    pipe.enable_sequential_cpu_offload()
    except Exception as e:
        print(f"Error loading the model: {e}")
        return None

    # Publish to the cache only once the pipe is fully built, so a failed load
    # can never leave LOADED_PIPE and LOADED_MODEL_PATH out of sync.
    LOADED_PIPE = pipe
    LOADED_MODEL_PATH = model_path
    return LOADED_PIPE
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "2d618b68-aa75-49a0-90cb-d5a982f873b1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Load the SDXL base checkpoint onto the GPU, keeping its default scheduler.
pipe = load_diffusers_pipe(
    model_path="models/sd_xl_base_1.0_0.9vae.safetensors",
    scheduler=None,
    device="cuda",
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "e29eaad2-cfd9-496b-84f3-382ca869b6cc", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
@torch.no_grad()
def tokenize(text, tokenizer, text_encoder, device):
    """Tokenize ``text`` and run it through ``text_encoder``.

    The text is padded/truncated to ``tokenizer.model_max_length`` tokens.

    Returns:
        ``(first_output, penultimate_hidden_state)``: element 0 of the encoder
        output and its second-to-last hidden state. For SDXL's second text
        encoder element 0 is presumably the pooled/projected embedding — confirm
        against the transformers version in use.
    """
    token_ids = tokenizer(
        text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    outputs = text_encoder(token_ids.to(device), output_hidden_states=True)
    return outputs[0], outputs.hidden_states[-2]


@torch.no_grad()
def encode(
    positive_prompt,
    negative_prompt,
    positive_keywords,
    negative_keywords,
    clip_encoder,
    clip_tokenizer,
    openclip_encoder,
    openclip_tokenizer,
    repeats=4,
    device="cuda",
):
    """Build SDXL prompt embeddings from separate keyword and prompt texts.

    The keyword texts go through the CLIP encoder and the prompt texts through
    the OpenCLIP encoder. Their penultimate hidden states are concatenated along
    the last dimension. The OpenCLIP first output is used as the pooled embedding.

    Args:
        positive_prompt / negative_prompt: text for the OpenCLIP encoder.
        positive_keywords / negative_keywords: text for the CLIP encoder.
        clip_encoder, clip_tokenizer: first text encoder and its tokenizer.
        openclip_encoder, openclip_tokenizer: second text encoder and its tokenizer.
        repeats: number of copies along the batch dimension.
        device: device the token ids are sent to (default "cuda", as before).

    Returns:
        ``(positives, negatives, pos_pool, neg_pool)``. The embeddings are
        repeated ``repeats`` times along dim 0. The pooled tensors are flattened
        to ``(repeats, -1)``.
    """
    _, pos_clip = tokenize(positive_keywords, clip_tokenizer, clip_encoder, device)
    pos_pool, pos_openclip = tokenize(positive_prompt, openclip_tokenizer, openclip_encoder, device)

    _, neg_clip = tokenize(negative_keywords, clip_tokenizer, clip_encoder, device)
    neg_pool, neg_openclip = tokenize(negative_prompt, openclip_tokenizer, openclip_encoder, device)

    positives = torch.concat([pos_clip, pos_openclip], dim=-1).repeat(repeats, 1, 1)
    negatives = torch.concat([neg_clip, neg_openclip], dim=-1).repeat(repeats, 1, 1)
    pos_pool = pos_pool.repeat(repeats, 1, 1).view(repeats, -1)
    neg_pool = neg_pool.repeat(repeats, 1, 1).view(repeats, -1)
    return positives, negatives, pos_pool, neg_pool
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "36453f50-e900-42a3-aa78-1261d825ee9d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#pipe.load_lora_weights(\"loras/helltaker-000002.safetensors\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "a3475024-bef5-4a34-ba57-eb31e12ab463", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Same text for both encoders; a single copy of each embedding (repeats=1).
pos, neg, pos_pool, neg_pool = encode(
    positive_prompt="Cute dog",
    negative_prompt="Poorly drawn, unrealistic, not cute, not colorful",
    positive_keywords="Cute dog",
    negative_keywords="Poorly drawn, unrealistic, not cute, not colorful",
    clip_encoder=pipe.text_encoder,
    clip_tokenizer=pipe.tokenizer,
    openclip_encoder=pipe.text_encoder_2,
    openclip_tokenizer=pipe.tokenizer_2,
    repeats=1,
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "79fd6965-f2aa-4f23-a898-4e6de8ff30fd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<torch._C.Generator at 0x7fb90a805df0>" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Fixed seed for the sampling generator so a fresh run reproduces the same image.
SEED = 1337
# Assign the seeded generator (manual_seed returns it) rather than ending the
# cell on a bare expression that just prints the generator's repr.
generator = torch.Generator("cuda").manual_seed(SEED)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "94095b2f-e550-443a-bdb9-bc4a812b0e44", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from diffusers import AutoencoderTiny\n", | |
"from PIL import Image\n", | |
"from IPython.display import display\n", | |
"import ipywidgets as widgets\n", | |
"import io\n", | |
"import numpy as np\n", | |
# Tiny autoencoder from the local "vaes/taesdxl" directory, used by the step
# callback for cheap previews. Assigning the result of .to() (which returns the
# module) keeps the cell from dumping the whole model repr.
TINY_AUTOENCODER = AutoencoderTiny.from_pretrained("vaes/taesdxl", torch_dtype=torch.float16).to("cuda")
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "8e658f24-7195-4cc7-b0e1-d179335029fa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def decode_callback(step, t, latents):
    """Pipeline step callback: show a preview of the current latents.

    Decodes ``latents`` with the global ``TINY_AUTOENCODER``, takes the first
    image of the batch and writes it to the global ``image_widget`` as PNG bytes.

    Args:
        step: denoising step index (unused).
        t: current timestep (unused).
        latents: the pipeline's in-progress latent batch.
    """
    # Index the first image of the batch explicitly; the previous squeeze(0)
    # left a 4-D tensor (and crashed in permute) when the batch had >1 image.
    decoded = TINY_AUTOENCODER.decode(latents).sample[0]
    img_np = decoded.permute(1, 2, 0).detach().float().cpu().numpy()
    # Map the decoder output from [-1, 1] to [0, 1] for display.
    img_np = np.clip((img_np + 1) / 2.0, 0, 1)
    image_widget.value = to_png_image(img_np)


def to_png_image(img_np):
    """Encode an HxWxC float array with values in [0, 1] as PNG bytes.

    ``compress_level=0`` skips compression so per-step previews stay cheap.
    """
    img = Image.fromarray((img_np * 255).astype(np.uint8))
    buf = io.BytesIO()
    img.save(buf, format='png', compress_level=0)
    return buf.getvalue()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "8b6d8250-649e-43ad-b0fa-fe7f18f5ed06", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Live preview widget that decode_callback updates at every denoising step.
image_widget = widgets.Image(format='png')
display(image_widget)

images = pipe(
    prompt_embeds=pos,
    negative_prompt_embeds=neg,
    pooled_prompt_embeds=pos_pool,
    negative_pooled_prompt_embeds=neg_pool,
    num_inference_steps=20,
    generator=generator,
    callback=decode_callback,
).images
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "864a3ec9-9765-471a-b351-411b14317b84", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
import math

import matplotlib.pyplot as plt

# Two images per row, with the grid sized to the number of images actually
# generated. A fixed 2x2 grid left empty, framed panels when repeats=1.
n_images = max(1, len(images))
ncols = min(2, n_images)
nrows = math.ceil(n_images / ncols)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10 * ncols, 10 * nrows), squeeze=False)
for ax in axs.flat:
    ax.axis('off')  # hide frames everywhere, including any unused panel
for img, ax in zip(images, axs.flat):
    ax.imshow(img)
plt.show()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9da282ae-c9df-43ac-ba9f-aa898153c31a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Keep the first generated image around for further use.
image = images[0]
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.