Skip to content

Instantly share code, notes, and snippets.

@wfjsw
Last active November 24, 2022 00:50
Show Gist options
  • Save wfjsw/2b2a26349bef1ce891f6ab4d4fb3030a to your computer and use it in GitHub Desktop.
Save wfjsw/2b2a26349bef1ce891f6ab4d4fb3030a to your computer and use it in GitHub Desktop.
convert-pt-embedding-to-png.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"private_outputs": true,
"authorship_tag": "ABX9TyOObjGuSSNAzbXIxhUQiolx",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/wfjsw/2b2a26349bef1ce891f6ab4d4fb3030a/convert-pt-embedding-to-png.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# 将 pt 格式的训练模型文件转换为 png 格式\n",
"\n",
"**注意:本工具仅限于处理 500 KB 以下的 Textual Inversion Embedding 模型。不能用来处理 80 MB 以上的 Hypernetwork 模型。**\n",
"\n",
"转换后的 `.png` 文件可代替原 `.pt` 文件,放在 `embeddings` 目录中使用。\n",
"\n",
"## 准备\n",
"\n",
"1. .pt 格式的训练模型文件\n",
"2. 一张 512x512 的 PNG 图片,用作演示图\n",
"\n",
"## 运行\n",
"\n",
"按顺序运行下方的两个单元格。第一个单元格执行完毕后会提示重启运行时,点击提示下方的按钮重启。\n",
"\n",
"执行第二个单元格时,在提示出现时依次上传模型文件和演示图。转换完成后会自动下载结果。"
],
"metadata": {
"id": "ytG1_u2pmgzH"
}
},
{
"cell_type": "code",
"source": [
"#@title 安装软件包\n",
"# apt-get rather than apt: apt warns that its CLI is not stable for use in scripts.\n",
"!apt-get install -y -q fonts-roboto\n",
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"# NOTE(review): numpy 1.21.x only ships wheels for Python <= 3.10; relax this pin if the runtime is newer.\n",
"%pip install numpy==1.21.6 Pillow==9.2.0"
],
"metadata": {
"cellView": "form",
"id": "GOp2FjbdmCL3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 转换\n",
"\n",
"from PIL import Image, PngImagePlugin, ImageDraw, ImageFont\n",
"import json\n",
"import base64\n",
"import numpy as np\n",
"import zlib\n",
"import torch\n",
"from google.colab import files\n",
"\n",
"class EmbeddingEncoder(json.JSONEncoder):\n",
"    \"\"\"JSON encoder that stores torch tensors as {'TORCHTENSOR': <nested list>}.\"\"\"\n",
"\n",
"    def default(self, obj):\n",
"        if not isinstance(obj, torch.Tensor):\n",
"            # Let the base class handle (i.e. reject) anything that is not a tensor.\n",
"            return super().default(obj)\n",
"        as_list = obj.cpu().detach().numpy().tolist()\n",
"        return {'TORCHTENSOR': as_list}\n",
"\n",
"def embedding_to_b64(data):\n",
"    \"\"\"Return the JSON form of `data` (tensors via EmbeddingEncoder) as base64-encoded bytes.\"\"\"\n",
"    as_json = json.dumps(data, cls=EmbeddingEncoder)\n",
"    return base64.b64encode(as_json.encode())\n",
"\n",
"def lcg(m=2**32, a=1664525, c=1013904223, seed=0):\n",
"    \"\"\"Endless linear congruential generator: state = (a * state + c) % m, yielding state % 255.\n",
"\n",
"    Because of the final `% 255` the yielded values lie in 0..254. xor_block calls lcg()\n",
"    with these defaults on every call, so the constants determine its XOR mask.\n",
"    \"\"\"\n",
"    state = seed\n",
"    while True:\n",
"        state = (state * a + c) % m\n",
"        yield state % 255\n",
"\n",
"def xor_block(block):\n",
"    \"\"\"Return `block` as uint8 with its low nibbles XOR-ed by a fixed pseudo-random mask.\n",
"\n",
"    The mask is the low nibble of a freshly started lcg() stream, so it depends only on the\n",
"    number of elements; applying xor_block twice gives back block.astype(np.uint8).\n",
"    \"\"\"\n",
"    g = lcg()\n",
"    # block.size replaces np.product(block.shape): np.product is deprecated since NumPy 1.25\n",
"    # and removed in NumPy 2.0.\n",
"    randblock = np.array([next(g) for _ in range(block.size)]).astype(np.uint8).reshape(block.shape)\n",
"    return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)\n",
"\n",
"\n",
"def style_block(block, sequence):\n",
"    \"\"\"Return `block` XOR-ed with a decorative dot pattern confined to the high nibble.\n",
"\n",
"    Dots with a 6 px bounding box are laid on an 8 px grid (rows with an even index shifted\n",
"    right by 4 px) over an RGB canvas shaped like `block`; dot grey levels cycle through\n",
"    `sequence`.\n",
"    \"\"\"\n",
"    height, width = block.shape[0], block.shape[1]\n",
"    canvas = Image.new('RGB', (width, height))\n",
"    pen = ImageDraw.Draw(canvas)\n",
"    dot = 0\n",
"    for left in range(-6, width, 8):\n",
"        for row, top in enumerate(range(-6, height, 8)):\n",
"            shift = 4 if row % 2 == 0 else 0\n",
"            tone = sequence[dot % len(sequence)]\n",
"            dot += 1\n",
"            pen.ellipse((left + shift, top, left + shift + 6, top + 6), fill=(tone, tone, tone))\n",
"\n",
"    # Keep only the high nibble so the low nibble of `block` is left untouched.\n",
"    pattern = np.array(canvas).astype(np.uint8) & 0xF0\n",
"    return block ^ pattern\n",
"\n",
"def _fitted_font_size(draw, text, font, size, target_width, max_size=72):\n",
"    \"\"\"Scale `size` so `text`, measured with `font` (assumed to be at `size`), spans about `target_width` px.\n",
"\n",
"    Returns `max_size` for empty text (or a zero-width box) instead of dividing by zero,\n",
"    caps the result at `max_size`, and never returns less than 1.\n",
"    \"\"\"\n",
"    if not text:\n",
"        return max_size\n",
"    _, _, w, _ = draw.textbbox((0, 0), text, font=font)\n",
"    if w <= 0:\n",
"        return max_size\n",
"    return max(1, min(int(size * (target_width / w)), max_size))\n",
"\n",
"\n",
"def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):\n",
"    \"\"\"Return an RGBA copy of `srcimage` with darkened edges, a title and three footer captions.\n",
"\n",
"    Args:\n",
"        srcimage: PIL image to caption; it is copied, not modified.\n",
"        title: text drawn at the top-left corner; may be empty.\n",
"        footerLeft, footerMid, footerRight: texts aligned left / centre / right along the\n",
"            bottom edge; any of them may be empty. All three share one font size.\n",
"        textfont: path to a TrueType font; defaults to Roboto Regular (the first cell\n",
"            installs the fonts-roboto package).\n",
"    \"\"\"\n",
"    from math import cos\n",
"\n",
"    if textfont is None:\n",
"        # Only the path is needed here; font objects are created below once sizes are known.\n",
"        textfont = '/usr/share/fonts/truetype/roboto/Roboto-Regular.ttf'\n",
"\n",
"    image = srcimage.copy()\n",
"    padding = 10\n",
"    text_fill = (255, 255, 255, 230)\n",
"\n",
"    # Darken the image with a vertical alpha gradient that is strongest at the top and bottom\n",
"    # edges, so the white text stays legible.\n",
"    factor = 1.5\n",
"    gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))\n",
"    for y in range(image.size[1]):\n",
"        mag = 1-cos(y/image.size[1]*factor)\n",
"        mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))\n",
"        gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))\n",
"    image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))\n",
"\n",
"    draw = ImageDraw.Draw(image)\n",
"\n",
"    # Title: start at 32 pt and rescale so it spans ~75% of the width (minus padding).\n",
"    fontsize = 32\n",
"    font = ImageFont.truetype(textfont, fontsize)\n",
"    fontsize = _fitted_font_size(draw, title, font, fontsize, (image.size[0]*0.75)-(padding*4))\n",
"    font = ImageFont.truetype(textfont, fontsize)\n",
"    if title:\n",
"        draw.text((padding, padding), title, anchor='lt', font=font, fill=text_fill)\n",
"\n",
"    # Footers: each may use about a third of the width; the smallest fitting size is shared.\n",
"    footer_width = (image.size[0]/3)-padding\n",
"    footer_size = min(\n",
"        _fitted_font_size(draw, text, font, fontsize, footer_width)\n",
"        for text in (footerLeft, footerMid, footerRight)\n",
"    )\n",
"    font = ImageFont.truetype(textfont, footer_size)\n",
"\n",
"    baseline = image.size[1]-padding\n",
"    for x, text, anchor in (\n",
"        (padding, footerLeft, 'ls'),\n",
"        (image.size[0]/2, footerMid, 'ms'),\n",
"        (image.size[0]-padding, footerRight, 'rs'),\n",
"    ):\n",
"        if text:\n",
"            draw.text((x, baseline), text, anchor=anchor, font=font, fill=text_fill)\n",
"\n",
"    return image\n",
"\n",
"def insert_image_data_embed(image, data):\n",
"    \"\"\"Return a new RGB image with `data` hidden in pixel strips to the left and right of `image`.\n",
"\n",
"    `data` is JSON-serialised (tensors via EmbeddingEncoder) and zlib-compressed. The low\n",
"    nibble of every compressed byte goes into the left strip, the high nibble into the right\n",
"    strip. Each strip is zero-padded, reshaped to (image height, -1, 3), decorated in its high\n",
"    bits by style_block and scrambled in its low bits by xor_block. A one-pixel black column\n",
"    separates each strip from the picture.\n",
"    \"\"\"\n",
"    channels = 3\n",
"    height = image.size[1]\n",
"\n",
"    payload = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)\n",
"    packed = np.frombuffer(payload, np.uint8).copy()\n",
"    n_bytes = packed.shape[0]\n",
"\n",
"    # Pad up to a multiple of the height, then up to a multiple of height * channels.\n",
"    # Each step always adds at least one element, even when the length is already aligned.\n",
"    padded_len = n_bytes + (height - (n_bytes % height))\n",
"    padded_len += (height * channels) - (padded_len % (height * channels))\n",
"\n",
"    def as_strip(nibbles):\n",
"        return np.pad(nibbles, (0, padded_len - n_bytes)).reshape((height, -1, channels))\n",
"\n",
"    low_strip = as_strip(packed & 0x0F)\n",
"    high_strip = as_strip(packed >> 4)\n",
"\n",
"    # Dot shades: magnitudes of (up to 1024 values of) the first row of the first parameter\n",
"    # tensor, normalised to 0..255.\n",
"    first_row = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]\n",
"    magnitudes = np.abs(first_row)\n",
"    shades = (magnitudes / np.max(magnitudes) * 255).astype(np.uint8)\n",
"\n",
"    low_strip = xor_block(style_block(low_strip, sequence=shades))\n",
"    high_strip = xor_block(style_block(high_strip, sequence=shades[::-1]))\n",
"\n",
"    left = Image.fromarray(low_strip, mode='RGB')\n",
"    right = Image.fromarray(high_strip, mode='RGB')\n",
"\n",
"    picture_x = left.size[0] + 1\n",
"    right_x = picture_x + image.size[0] + 1\n",
"    canvas = Image.new('RGB', (right_x + right.size[0], height), (0, 0, 0))\n",
"    canvas.paste(left, (0, 0))\n",
"    canvas.paste(image, (picture_x, 0))\n",
"    canvas.paste(right, (right_x, 0))\n",
"    return canvas\n",
"\n",
"\n",
"print(\"选择你要上传的 .pt 文件\")\n",
"files.upload_file('embedding.pt')\n",
"\n",
"print(\"选择一张 512x512 PNG 效果图\")\n",
"files.upload_file('cover.png')\n",
"\n",
"with Image.open('cover.png') as img:\n",
"    if img.width != 512 or img.height != 512:\n",
"        # Raise instead of exit(): in a notebook kernel exit() does not stop the rest of this cell.\n",
"        raise ValueError('PNG 图片尺寸非 512x512(当前为 {}x{})'.format(img.width, img.height))\n",
"    cover = img.copy()  # detached copy, still usable after the file is closed\n",
"\n",
"info = PngImagePlugin.PngInfo()\n",
"# NOTE: torch.load unpickles the file and can run arbitrary code; only convert embeddings you trust.\n",
"data = torch.load('embedding.pt', map_location='cpu')\n",
"\n",
"# Remove the stored name so the file name can be used instead; tolerate files without one.\n",
"data.pop('name', None)\n",
"\n",
"info.add_text(\"sd-ti-embedding\", embedding_to_b64(data))\n",
"\n",
"title = \"\"\n",
"\n",
"try:\n",
"    vectorSize = list(data['string_to_param'].values())[0].shape[0]\n",
"except Exception:\n",
"    vectorSize = '?'\n",
"\n",
"footer_left = '[{}]'.format(data.get('sd_checkpoint', 'unknown') or 'unknown')\n",
"footer_mid = \"\"\n",
"footer_right = '{}v {}s'.format(vectorSize, data.get('step', '?') or '?')\n",
"\n",
"captioned_image = caption_image_overlay(cover, title, footer_left, footer_mid, footer_right)\n",
"captioned_image = insert_image_data_embed(captioned_image, data)\n",
"\n",
"captioned_image.save('embedding.png', \"PNG\", pnginfo=info)\n",
"display(captioned_image)\n",
"\n",
"cover.close()\n",
"captioned_image.close()\n",
"files.download('embedding.png')"
],
"metadata": {
"id": "b3zADOHI6jkh"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment