@CoffeeVampir3
Created February 19, 2024 21:30

vae-preview for SDXL
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d62240a6-6331-41fa-8b89-bc16e72a4425",
"metadata": {},
"outputs": [],
"source": [
"import lovely_tensors as lt\n",
"lt.monkey_patch()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7c693e12-6922-4e7e-b610-4e343a05c108",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-08-21 13:57:23.739433: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2023-08-21 13:57:23.790664: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2023-08-21 13:57:24.044699: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"import torch\n",
"from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline\n",
"from PIL import Image\n",
"import os, gc, random, sys, json, random, time\n",
"\n",
"class NoWatermarker:\n",
" def __init__(self):\n",
" pass\n",
"\n",
" def apply_watermark(self, images: torch.FloatTensor):\n",
" return images #Fake watermarker that bypasses watermarking.\n",
"\n",
"global LOADED_PIPE\n",
"global LOADED_MODEL_PATH\n",
"LOADED_PIPE = None\n",
"LOADED_MODEL_PATH = None\n",
"def load_diffusers_pipe(model_path, scheduler, device):\n",
" global LOADED_PIPE\n",
" global LOADED_MODEL_PATH\n",
" torch.set_float32_matmul_precision('high')\n",
" \n",
" #Check if there's already a loaded pipe that matches what kind of pipe we want.\n",
" if (LOADED_MODEL_PATH and\n",
" LOADED_MODEL_PATH == model_path and\n",
" LOADED_PIPE):\n",
" \n",
" #Check if the scheduler is the correct one we want\n",
" if ((not LOADED_PIPE.scheduler) or\n",
" LOADED_PIPE.scheduler.__class__.__name__ != scheduler.__class__.__name__):\n",
" load_scheduler(LOADED_PIPE, scheduler)\n",
" \n",
" return LOADED_PIPE\n",
" \n",
" #load new pipe\n",
" try:\n",
" LOADED_PIPE = StableDiffusionXLPipeline.from_single_file(\n",
" model_path, \n",
" torch_dtype=torch.float16, \n",
" use_safetensors=True, \n",
" variant=\"fp16\",\n",
" add_watermarker=False)\n",
" LOADED_PIPE.watermark = NoWatermarker()\n",
" \n",
" LOADED_PIPE.enable_vae_tiling()\n",
" LOADED_PIPE.to(device)\n",
" #if device != \"cpu\":\n",
" # LOADED_PIPE.enable_sequential_cpu_offload()\n",
" LOADED_MODEL_PATH = model_path\n",
" except Exception as e:\n",
" print(f\"Error loading the model: {e}\")\n",
" \n",
" return LOADED_PIPE"
]
},
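{
"cell_type": "markdown",
"id": "b7a1c9d2-3e4f-4a5b-8c6d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"`load_diffusers_pipe` memoizes the pipeline in module-level globals, so re-running the loader with the same `model_path` reuses the pipe already on the GPU. The inverse operation is not part of this gist, but a minimal sketch for freeing VRAM before switching checkpoints could look like this (a hypothetical helper; the name is an assumption):\n",
"\n",
"```python\n",
"import gc, torch\n",
"\n",
"def unload_pipe():\n",
"    #Drop the cached pipe so CUDA memory can actually be reclaimed.\n",
"    global LOADED_PIPE, LOADED_MODEL_PATH\n",
"    LOADED_PIPE = None\n",
"    LOADED_MODEL_PATH = None\n",
"    gc.collect()\n",
"    torch.cuda.empty_cache()\n",
"```"
]
},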
{
"cell_type": "code",
"execution_count": 3,
"id": "2d618b68-aa75-49a0-90cb-d5a982f873b1",
"metadata": {},
"outputs": [],
"source": [
"pipe = load_diffusers_pipe(\"models/sd_xl_base_1.0_0.9vae.safetensors\", None, \"cuda\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e29eaad2-cfd9-496b-84f3-382ca869b6cc",
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def tokenize(text, tokenizer, text_encoder, device):\n",
" text_inputs = tokenizer(text, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n",
" text_input_ids = text_inputs.input_ids\n",
" untruncated_ids = tokenizer(text, padding=\"longest\", return_tensors=\"pt\").input_ids\n",
"\n",
" prompt_embeds = text_encoder(\n",
" text_input_ids.to(device),\n",
" output_hidden_states=True,\n",
" )\n",
" return prompt_embeds[0], prompt_embeds.hidden_states[-2]\n",
"\n",
"@torch.no_grad()\n",
"def encode(\n",
" positive_prompt,\n",
" negative_prompt,\n",
" positive_keywords,\n",
" negative_keywords,\n",
" clip_encoder,\n",
" clip_tokenizer,\n",
" openclip_encoder,\n",
" openclip_tokenizer,\n",
" repeats=4,\n",
"):\n",
" device = \"cuda\"\n",
"\n",
" _, pos_clip = tokenize(positive_keywords, clip_tokenizer, clip_encoder, device)\n",
" pos_pool, pos_openclip = tokenize(positive_prompt, openclip_tokenizer, openclip_encoder, device)\n",
"\n",
" _, neg_clip = tokenize(negative_keywords, clip_tokenizer, clip_encoder, device)\n",
" neg_pool, neg_openclip = tokenize(negative_prompt, openclip_tokenizer, openclip_encoder, device)\n",
"\n",
" pos_encodings = [\n",
" pos_clip,\n",
" pos_openclip\n",
" ]\n",
"\n",
" neg_encodings = [\n",
" neg_clip,\n",
" neg_openclip\n",
" ]\n",
"\n",
" positives = torch.concat(pos_encodings, dim=-1).repeat(repeats, 1, 1)\n",
" negatives = torch.concat(neg_encodings, dim=-1).repeat(repeats, 1, 1)\n",
" pos_pool = pos_pool.repeat(repeats, 1, 1).view(repeats, -1)\n",
" neg_pool = neg_pool.repeat(repeats, 1, 1).view(repeats, -1)\n",
" #encodings = [encoding.view(4, encoding.shape[1], -1) for encoding in encodings]\n",
" #prompt_embeds = torch.concat(embeddings, dim=-1)\n",
" #print(prompt_embeds)\n",
" return positives, negatives, pos_pool, neg_pool"
]
},
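{
"cell_type": "markdown",
"id": "c8b2d0e3-4f5a-4b6c-9d7e-1f2a3b4c5d6e",
"metadata": {},
"source": [
"`encode` reimplements SDXL's dual-encoder conditioning by hand: the penultimate hidden states of CLIP ViT-L (768-dim per token) and OpenCLIP ViT-bigG (1280-dim per token) are concatenated along the feature axis into the 2048-dim `prompt_embeds` the UNet expects, and the pooled output of the second encoder supplies `pooled_prompt_embeds`. For comparison, recent diffusers versions expose the same logic as `pipe.encode_prompt`; a hedged rough equivalent (check the signature against your installed version):\n",
"\n",
"```python\n",
"pe, npe, ppe, nppe = pipe.encode_prompt(\n",
"    prompt=\"Cute dog\",\n",
"    negative_prompt=\"Poorly drawn, unrealistic, not cute, not colorful\",\n",
"    device=\"cuda\",\n",
"    num_images_per_prompt=1,\n",
"    do_classifier_free_guidance=True,\n",
")\n",
"```"
]
},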
{
"cell_type": "code",
"execution_count": null,
"id": "36453f50-e900-42a3-aa78-1261d825ee9d",
"metadata": {},
"outputs": [],
"source": [
"#pipe.load_lora_weights(\"loras/helltaker-000002.safetensors\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a3475024-bef5-4a34-ba57-eb31e12ab463",
"metadata": {},
"outputs": [],
"source": [
"pos, neg, pos_pool, neg_pool = encode(\n",
" \"Cute dog\", \n",
" \"Poorly drawn, unrealistic, not cute, not colorful\", \n",
" \"Cute dog\", \n",
" \"Poorly drawn, unrealistic, not cute, not colorful\", \n",
" pipe.text_encoder, pipe.tokenizer, pipe.text_encoder_2, pipe.tokenizer_2, repeats=1)"
]
},
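{
"cell_type": "markdown",
"id": "d9c3e1f4-5a6b-4c7d-8e9f-2a3b4c5d6e7f",
"metadata": {},
"source": [
"A quick shape check (an addition, not in the original gist): with one prompt and `repeats=1`, the concatenated embeddings should come out as `(1, 77, 2048)` and the pooled embedding as `(1, 1280)`.\n",
"\n",
"```python\n",
"assert pos.shape == neg.shape and pos.shape[-1] == 768 + 1280\n",
"assert pos_pool.shape[-1] == 1280\n",
"```"
]
},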
{
"cell_type": "code",
"execution_count": 6,
"id": "79fd6965-f2aa-4f23-a898-4e6de8ff30fd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fb90a805df0>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generator = torch.Generator(\"cuda\")\n",
"generator.manual_seed(1337)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94095b2f-e550-443a-bdb9-bc4a812b0e44",
"metadata": {},
"outputs": [],
"source": [
"from diffusers import AutoencoderTiny\n",
"from PIL import Image\n",
"from IPython.display import display\n",
"import ipywidgets as widgets\n",
"import io\n",
"import numpy as np\n",
"global TINY_AUTOENCODER\n",
"#TINY_AUTOENCODER = AutoencoderTiny.from_pretrained(\"vaes/taesdxl\", ).to(\"cuda\", torch.float16)\n",
"TINY_AUTOENCODER = AutoencoderTiny.from_pretrained(\"vaes/taesdxl\", torch_dtype=torch.float16)\n",
"TINY_AUTOENCODER.to(\"cuda\")"
]
},
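{
"cell_type": "markdown",
"id": "e0d4f2a5-6b7c-4d8e-9f0a-3b4c5d6e7f8a",
"metadata": {},
"source": [
"`AutoencoderTiny` (TAESD) is a distilled stand-in for the full SDXL VAE: it decodes latents in a single cheap pass, which is what makes a per-step preview affordable. A smoke test (an addition, not in the original; the latent is random, so the output is noise):\n",
"\n",
"```python\n",
"#A (1, 4, 128, 128) SDXL latent decodes to a (1, 3, 1024, 1024) image.\n",
"with torch.no_grad():\n",
"    sample = TINY_AUTOENCODER.decode(\n",
"        torch.randn(1, 4, 128, 128, device=\"cuda\", dtype=torch.float16)\n",
"    ).sample\n",
"print(sample.shape)\n",
"```"
]
},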
{
"cell_type": "code",
"execution_count": 8,
"id": "8e658f24-7195-4cc7-b0e1-d179335029fa",
"metadata": {},
"outputs": [],
"source": [
"def decode_callback(step, t, latents):\n",
" latent = latents\n",
" img = TINY_AUTOENCODER.decode(latents)\n",
" img_np = img[0].squeeze(0).permute(1, 2, 0).cpu().detach().numpy().astype('float32')\n",
" img_np = np.clip((img_np + 1) / 2.0, 0, 1)\n",
" image_widget.value = to_png_image(img_np)\n",
"\n",
"def to_png_image(img_np):\n",
" \"\"\"Convert a numpy array to PNG format image.\"\"\"\n",
" img = Image.fromarray((img_np * 255).astype(np.uint8))\n",
" buf = io.BytesIO()\n",
" img.save(buf, format='png', compress_level=0)\n",
" return buf.getvalue()"
]
},
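{
"cell_type": "markdown",
"id": "f1e5a3b6-7c8d-4e9f-8a1b-4c5d6e7f8a9b",
"metadata": {},
"source": [
"The pipeline invokes `callback(step, t, latents)` once per `callback_steps` denoising steps (default 1). TAESD is fast enough to decode every step, but if the PNG encode or widget refresh ever becomes the bottleneck, a throttled wrapper is trivial (a sketch, not in the original; `PREVIEW_EVERY` is an assumed name):\n",
"\n",
"```python\n",
"PREVIEW_EVERY = 4\n",
"\n",
"def throttled_callback(step, t, latents):\n",
"    #Only refresh the live preview every few steps.\n",
"    if step % PREVIEW_EVERY == 0:\n",
"        decode_callback(step, t, latents)\n",
"```"
]
},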
{
"cell_type": "code",
"execution_count": null,
"id": "8b6d8250-649e-43ad-b0fa-fe7f18f5ed06",
"metadata": {},
"outputs": [],
"source": [
"image_widget = widgets.Image(format='png')\n",
"display(image_widget)\n",
"\n",
"images = pipe(callback=decode_callback, prompt_embeds = pos, negative_prompt_embeds = neg, pooled_prompt_embeds=pos_pool, negative_pooled_prompt_embeds=neg_pool, num_inference_steps=20, generator=generator).images"
]
},
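{
"cell_type": "markdown",
"id": "a2f6b4c7-8d9e-4f0a-9b2c-5d6e7f8a9b0c",
"metadata": {},
"source": [
"Note that newer diffusers releases deprecate `callback`/`callback_steps` in favor of `callback_on_step_end`. A rough equivalent under that API (hedged; verify against your installed version):\n",
"\n",
"```python\n",
"def on_step_end(pipeline, step, t, callback_kwargs):\n",
"    decode_callback(step, t, callback_kwargs[\"latents\"])\n",
"    return callback_kwargs\n",
"\n",
"#images = pipe(..., callback_on_step_end=on_step_end,\n",
"#              callback_on_step_end_tensor_inputs=[\"latents\"]).images\n",
"```"
]
},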
{
"cell_type": "code",
"execution_count": null,
"id": "864a3ec9-9765-471a-b351-411b14317b84",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(20,20))\n",
"for img, ax in zip(images, axs.flatten()):\n",
" ax.imshow(img)\n",
" ax.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9da282ae-c9df-43ac-ba9f-aa898153c31a",
"metadata": {},
"outputs": [],
"source": [
"image = images[0]\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}