Last active
November 24, 2022 00:50
-
-
Save wfjsw/2b2a26349bef1ce891f6ab4d4fb3030a to your computer and use it in GitHub Desktop.
convert-pt-embedding-to-png.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"private_outputs": true, | |
"authorship_tag": "ABX9TyOObjGuSSNAzbXIxhUQiolx", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/wfjsw/2b2a26349bef1ce891f6ab4d4fb3030a/convert-pt-embedding-to-png.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 将 pt 格式的训练模型文件转换为 png 格式\n", | |
"\n", | |
"**注意:本工具仅限于处理 500 KB 以下的 Textual Inversion Embedding 模型。不能用来处理 80 MB 以上的 Hypernetwork 模型。**\n", | |
"\n", | |
"转换后的 `.PNG` 文件可代替原 `.pt` 文件,放在 `embeddings` 目录中使用。\n", | |
"\n", | |
"## 准备\n", | |
"\n", | |
"1. .pt 格式的训练模型文件\n", | |
"2. 一个 512x512 大小的 PNG 图片用于演示图\n", | |
"\n", | |
"## 运行\n", | |
"\n", | |
"按顺序运行下方的两个单元格。第一个单元格执行完毕后会有提示重启笔记本,点击提示下方的按钮重启。\n", | |
"\n", | |
"执行第二个单元格时,在提示出现时依次上传模型文件和演示图。转换完成后会自动下载结果。" | |
], | |
"metadata": { | |
"id": "ytG1_u2pmgzH" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title 安装软件包\n", | |
"!apt install -y -q fonts-roboto\n", | |
"!pip install numpy==1.21.6 Pillow==9.2.0 " | |
], | |
"metadata": { | |
"cellView": "form", | |
"id": "GOp2FjbdmCL3" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title 转换\n", | |
"\n", | |
"from PIL import Image, PngImagePlugin, ImageDraw, ImageFont\n", | |
"import json\n", | |
"import base64\n", | |
"import numpy as np\n", | |
"import zlib\n", | |
"import torch\n", | |
"from google.colab import files\n", | |
"\n", | |
"class EmbeddingEncoder(json.JSONEncoder):\n", | |
" def default(self, obj):\n", | |
" if isinstance(obj, torch.Tensor):\n", | |
" return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}\n", | |
" return json.JSONEncoder.default(self, obj)\n", | |
"\n", | |
"def embedding_to_b64(data):\n", | |
" d = json.dumps(data, cls=EmbeddingEncoder)\n", | |
" return base64.b64encode(d.encode())\n", | |
"\n", | |
"def lcg(m=2**32, a=1664525, c=1013904223, seed=0):\n", | |
" while True:\n", | |
" seed = (a * seed + c) % m\n", | |
" yield seed % 255\n", | |
"\n", | |
"def xor_block(block):\n", | |
" g = lcg()\n", | |
" randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)\n", | |
" return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)\n", | |
"\n", | |
"\n", | |
"def style_block(block, sequence):\n", | |
" im = Image.new('RGB', (block.shape[1], block.shape[0]))\n", | |
" draw = ImageDraw.Draw(im)\n", | |
" i = 0\n", | |
" for x in range(-6, im.size[0], 8):\n", | |
" for yi, y in enumerate(range(-6, im.size[1], 8)):\n", | |
" offset = 0\n", | |
" if yi % 2 == 0:\n", | |
" offset = 4\n", | |
" shade = sequence[i % len(sequence)]\n", | |
" i += 1\n", | |
" draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))\n", | |
"\n", | |
" fg = np.array(im).astype(np.uint8) & 0xF0\n", | |
"\n", | |
" return block ^ fg\n", | |
"\n", | |
"def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):\n", | |
" from math import cos\n", | |
"\n", | |
" image = srcimage.copy()\n", | |
"\n", | |
" if textfont is None:\n", | |
" try:\n", | |
" textfont = ImageFont.truetype('/usr/share/fonts/truetype/roboto/Roboto-Regular.ttf', fontsize)\n", | |
" textfont = '/usr/share/fonts/truetype/roboto/Roboto-Regular.ttf'\n", | |
" except Exception:\n", | |
" textfont = '/usr/share/fonts/truetype/roboto/Roboto-Regular.ttf'\n", | |
"\n", | |
" factor = 1.5\n", | |
" gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))\n", | |
" for y in range(image.size[1]):\n", | |
" mag = 1-cos(y/image.size[1]*factor)\n", | |
" mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))\n", | |
" gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))\n", | |
" image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))\n", | |
"\n", | |
" draw = ImageDraw.Draw(image)\n", | |
" fontsize = 32\n", | |
" font = ImageFont.truetype(textfont, fontsize)\n", | |
" padding = 10\n", | |
"\n", | |
" _, _, w, h = draw.textbbox((0, 0), title, font=font)\n", | |
" fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)\n", | |
" font = ImageFont.truetype(textfont, fontsize)\n", | |
" _, _, w, h = draw.textbbox((0, 0), title, font=font)\n", | |
" draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))\n", | |
"\n", | |
" _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)\n", | |
" fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)\n", | |
" _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)\n", | |
" fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)\n", | |
" _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)\n", | |
" fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)\n", | |
"\n", | |
" font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))\n", | |
"\n", | |
" draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))\n", | |
" draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))\n", | |
" draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))\n", | |
"\n", | |
" return image\n", | |
"\n", | |
"def insert_image_data_embed(image, data):\n", | |
" d = 3\n", | |
" data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)\n", | |
" data_np_ = np.frombuffer(data_compressed, np.uint8).copy()\n", | |
" data_np_high = data_np_ >> 4\n", | |
" data_np_low = data_np_ & 0x0F\n", | |
"\n", | |
" h = image.size[1]\n", | |
" next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))\n", | |
" next_size = next_size + ((h*d)-(next_size % (h*d)))\n", | |
"\n", | |
" data_np_low.resize(next_size)\n", | |
" data_np_low = data_np_low.reshape((h, -1, d))\n", | |
"\n", | |
" data_np_high.resize(next_size)\n", | |
" data_np_high = data_np_high.reshape((h, -1, d))\n", | |
"\n", | |
" edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]\n", | |
" edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)\n", | |
"\n", | |
" data_np_low = style_block(data_np_low, sequence=edge_style)\n", | |
" data_np_low = xor_block(data_np_low)\n", | |
" data_np_high = style_block(data_np_high, sequence=edge_style[::-1])\n", | |
" data_np_high = xor_block(data_np_high)\n", | |
"\n", | |
" im_low = Image.fromarray(data_np_low, mode='RGB')\n", | |
" im_high = Image.fromarray(data_np_high, mode='RGB')\n", | |
"\n", | |
" background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))\n", | |
" background.paste(im_low, (0, 0))\n", | |
" background.paste(image, (im_low.size[0]+1, 0))\n", | |
" background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))\n", | |
"\n", | |
" return background\n", | |
"\n", | |
"\n", | |
"print(\"选择你要上传的 .pt 文件\")\n", | |
"uploaded = files.upload_file('embedding.pt')\n", | |
"\n", | |
"print(\"选择一张 512x512 PNG 效果图\")\n", | |
"uploaded = files.upload_file('cover.png')\n", | |
"\n", | |
"img = Image.open('cover.png')\n", | |
"\n", | |
"if img.width != 512 or img.height != 512:\n", | |
" print(\"PNG 图片体积非 512x512\")\n", | |
" exit(1)\n", | |
"\n", | |
"info = PngImagePlugin.PngInfo()\n", | |
"data = torch.load('embedding.pt', map_location='cpu')\n", | |
"\n", | |
"# to allow the usage of filename\n", | |
"del data['name']\n", | |
"\n", | |
"info.add_text(\"sd-ti-embedding\", embedding_to_b64(data))\n", | |
"\n", | |
"# title = \"<{}>\".format(data.get('name', 'Unknown') or 'Unknown')\n", | |
"title = \"\"\n", | |
"\n", | |
"try:\n", | |
" vectorSize = list(data['string_to_param'].values())[0].shape[0]\n", | |
"except Exception as e:\n", | |
" vectorSize = '?'\n", | |
"\n", | |
"# footer_left = data.get('sd_checkpoint_name', 'Unknown') or 'Unknown'\n", | |
"footer_left = '[{}]'.format(data.get('sd_checkpoint', 'unknown') or 'unknown')\n", | |
"footer_mid = \"\"\n", | |
"footer_right = '{}v {}s'.format(vectorSize, data.get('step', '?') or '?')\n", | |
"\n", | |
"captioned_image = caption_image_overlay(img, title, footer_left, footer_mid, footer_right)\n", | |
"captioned_image = insert_image_data_embed(captioned_image, data)\n", | |
"\n", | |
"captioned_image.save('embedding.png', \"PNG\", pnginfo=info)\n", | |
"display(captioned_image)\n", | |
"\n", | |
"img.close()\n", | |
"captioned_image.close()\n", | |
"files.download('embedding.png')" | |
], | |
"metadata": { | |
"id": "b3zADOHI6jkh" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment