{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Public CLIP + Chunky RGB Optimization v0.1",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "U8rXELYhVtAG"
      },
      "source": [
        "!nvidia-smi"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WcuyIX-kdpre"
      },
      "source": [
        "!pip -qq install av\n",
        "!pip -qq install torch_optimizer"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UMWyysdmVkPF"
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchvision\n",
        "import torchvision.transforms as T\n",
        "import torchvision.transforms.functional as TF\n",
        "import torch_optimizer as optim\n",
        "import random\n",
        "import PIL\n",
        "from PIL import Image\n",
        "from IPython.display import display  # explicit import for display(); notebook kernels usually inject it\n",
        "import math\n",
        "import gc\n",
        "\n",
        "# note: a bare autocast() call constructs a context manager and discards it;\n",
        "# it only enables mixed precision when used as `with torch.cuda.amp.autocast(): ...`\n",
        "torch.cuda.amp.autocast(enabled=True)\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "torch.set_grad_enabled(False)\n",
        "\n",
        "def clear_mem():\n",
        "    torch.cuda.empty_cache()\n",
        "    gc.collect()\n",
        "\n",
        "ToTensor = T.ToTensor()\n",
        "ToImage = T.ToPILImage()\n",
        "\n",
        "def OpenImage(x, resize=None, convert=\"RGB\"):\n",
        "    if resize:\n",
        "        return ToTensor(Image.open(x).convert(convert).resize(resize)).unsqueeze(0).to(device)\n",
        "    else:\n",
        "        return ToTensor(Image.open(x).convert(convert)).unsqueeze(0).to(device)\n",
        "\n",
        "import warnings\n",
        "warnings.filterwarnings('ignore')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KF7gW_SWVlos"
      },
      "source": [
        "!pip install --no-deps git+https://github.com/openai/CLIP.git\n",
        "!pip install --no-deps ftfy regex tqdm\n",
        "\n",
        "import clip\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "torch.set_grad_enabled(False)\n",
        "\n",
        "dtype = torch.half\n",
        "\n",
        "perceptor, preprocess = clip.load(\"ViT-B/16\")#, jit=False)\n",
        "perceptor.eval().requires_grad_(False);\n",
        "\n",
        "CLIP_Normalization = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LHq8EjnPHfnL"
      },
      "source": [
        "!wget https://lospec.com/palette-list/dawnbringers-8-color-1x.png -O dawnbringers-8-color.png"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OuN4V5YWWGYi"
      },
      "source": [
"def diff_abs(x, y=0.0001):\n", | |
" return torch.sqrt(x*x+y)\n", | |
"\n", | |
"def diff_relu(x, y=0.0001):\n", | |
" return (torch.sqrt(x*x+y)+x)*0.5\n", | |
"\n", | |
"def diff_clamp(x, y=0.0001):\n", | |
" return diff_relu(1-diff_relu(1-x, y), y)\n", | |
"\n", | |
"def sign_pow(x, y):\n", | |
" return torch.pow(torch.abs(x), y) * torch.sign(x)\n", | |
"\n", | |
"def gaussian_sigma(x):\n", | |
" return 0.3 * ((x - 1) * 0.5 - 1) + 0.8\n", | |
"\n", | |
"def pseudo_gaussian(x):\n", | |
" x = torch.exp(-0.5*(x*3)**2)\n", | |
" return x\n", | |
"\n", | |
"def gaussian_pdf(x, sigma=0.333):\n", | |
" return torch.exp(-0.5 * (x / sigma).pow(2))\n", | |
"\n", | |
"def edge_preserving_blur(x, width=13, edge_mask_sigma=0.001):\n", | |
" x_blur = TF.gaussian_blur(x, width)\n", | |
" x_edge = (TF.rgb_to_grayscale(x_blur) - TF.rgb_to_grayscale(x))\n", | |
" x_edge = gaussian_pdf(x_edge, edge_mask_sigma)\n", | |
" x_edge = TF.gaussian_blur(x_edge, 3)\n", | |
" x_sub_edge = x * x_edge\n", | |
" x_sub_edge_blur = TF.gaussian_blur(x_sub_edge, width)\n", | |
" x_edge_blur = TF.gaussian_blur(x_edge, width)\n", | |
" result = x_sub_edge_blur / x_edge_blur.add(1e-8)\n", | |
" return result\n", | |
"\n", | |
"gradient = torch.linspace(-1,1,3).unsqueeze(0).tile(3,1).unsqueeze(0)\n", | |
"gradient = torch.cat([gradient, gradient.rot90(1,(-2,-1))]).unsqueeze(1).to(device)\n", | |
"\n", | |
"def gradient_conv(x):\n", | |
" x = TF.pad(x, 1, padding_mode='reflect')\n", | |
" x = F.conv2d(x, gradient)\n", | |
" return x\n", | |
"\n", | |
"def gradient_filter(x):\n", | |
" b, c, h, w = x.shape\n", | |
" x = x.reshape(b*c, 1, h, w)\n", | |
" x = TF.pad(x, 1, padding_mode='reflect')\n", | |
" x = F.conv2d(x, gradient)\n", | |
" b2, c2, h2, w2 = x.shape\n", | |
" x = x.permute(0,2,3,1)\n", | |
" x = x.reshape(b, c, h2, w2, 2)\n", | |
" return x\n", | |
"\n", | |
"def tv_loss(input):\n", | |
" \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n", | |
" input = F.pad(input, (0, 1, 0, 1), 'replicate')\n", | |
" x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n", | |
" y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n", | |
" return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n", | |
"\n", | |
"def highpass_loss(x, kernel, sigma):\n", | |
" x = x - TF.gaussian_blur(x, kernel, sigma)\n", | |
" x = pseudo_gaussian(x*0.5+0.5).sub(0.5).mul(2) * 0.25 + x * 0.75\n", | |
" x = (x**2).mean([1,2,3])\n", | |
" return x\n", | |
"\n", | |
"def symmetric_nn_loss(img):\n", | |
" rolls = []\n", | |
" width = 3\n", | |
" for x in range(width):\n", | |
" for y in range(width):\n", | |
" roll = torch.roll(img, (x-width//2,y-width//2), dims=(-2,-1)).unsqueeze(0)\n", | |
" rolls.append(roll)\n", | |
" del roll\n", | |
" rolls = torch.cat(rolls, 0)\n", | |
" pair_0 = (rolls[0]-rolls[8]).pow(2).add(1e-8).pow(0.5)\n", | |
" pair_1 = (rolls[1]-rolls[7]).pow(2).add(1e-8).pow(0.5)\n", | |
" pair_2 = (rolls[2]-rolls[6]).pow(2).add(1e-8).pow(0.5)\n", | |
" pair_3 = (rolls[3]-rolls[5]).pow(2).add(1e-8).pow(0.5)\n", | |
" loss = (pair_0 + pair_1 + pair_2 + pair_3) / 4\n", | |
" return loss\n", | |
"\n", | |
"def get_grid(x, center=[0.0, 0.0], angle=0.0, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0]):\n", | |
" matrix = TF._get_inverse_affine_matrix(center, angle, translate, scale, shear)\n", | |
" matrix = torch.tensor(matrix).reshape(1,2,3)\n", | |
" grid = F.affine_grid(matrix, x[0,None].shape, align_corners=False)\n", | |
" return grid\n", | |
"\n", | |
"def funny_img(x, center=[0.0, 0.0], angle=0.0, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0], rand=0.01, bend=1.0, random_crop=224):\n", | |
" with torch.no_grad():\n", | |
" center = torch.tensor(center).to(device)\n", | |
" angle_0 = torch.tensor(random.random()*angle*2-angle).float().to(device)\n", | |
" angle_1 = torch.tensor(random.random()*angle*2-angle).float().to(device)\n", | |
" angle_2 = torch.tensor(random.random()*360).float().to(device)\n", | |
" translate = torch.tensor(translate).to(device)\n", | |
" scale = torch.tensor(scale).to(device)\n", | |
" shear = torch.tensor(shear).to(device)\n", | |
" grid_0 = get_grid(x, center, angle_0, translate, scale, shear)\n", | |
" grid_1 = get_grid(x, center, angle_1, translate, scale, shear) * torch.rand(1,1,1,2).mul(0.1).add(0.9) + torch.rand(1,1,1,2) * 1/32\n", | |
" blob = F.interpolate(torch.randn(1,2,8,8).tanh(), (x.shape[-2], x.shape[-1]), mode='bicubic', align_corners=False).permute(0,2,3,1)\n", | |
" angle_2 = angle_2 * (math.pi/180)\n", | |
" ang_x = math.cos(angle_2)\n", | |
" ang_y = math.sin(angle_2)\n", | |
" ang = torch.tensor([ang_x, ang_y]).reshape(1,1,1,2)\n", | |
" gradient = F.affine_grid(torch.tensor([1.0, 0.0, 0.0, -0.0, 1.0, 0.0]).reshape(1,2,3).float(), x[0,None].shape, align_corners=True)\n", | |
" gradient = (gradient * ang).sum(-1, keepdim=True)\n", | |
" grid_mix = torch.lerp(grid_0, grid_1, gradient*bend).to(device)\n", | |
" grid_mix = torch.lerp(grid_mix, blob.to(device), rand).tile(x.shape[0],1,1,1)\n", | |
" grid_mix = T.RandomCrop(224)(grid_mix.permute(0,3,1,2)).permute(0,2,3,1)\n", | |
" x = F.grid_sample(x, grid_mix, align_corners=False, padding_mode='reflection')\n", | |
" return x\n", | |
"\n", | |
"def soft_clip(x, gain=1.0, mix=1.0):\n", | |
" return torch.lerp(x, x.mul(gain).tanh().div(gain), mix)\n", | |
"\n", | |
"def triangle_blur(x, kernel_size=3, pow=1.0):\n", | |
" padding = (kernel_size-1) // 2\n", | |
" b,c,h,w = x.shape\n", | |
" kernel = torch.linspace(-1,1,kernel_size+2)[1:-1].abs().neg().add(1).reshape(1,1,1,kernel_size).pow(pow).to(device)\n", | |
" kernel = kernel / kernel.sum()\n", | |
" x = x.reshape(b*c,1,h,w)\n", | |
" x = F.pad(x, (padding,padding,padding,padding), mode='reflect')\n", | |
" x = F.conv2d(x, kernel)\n", | |
" x = F.conv2d(x, kernel.permute(0,1,3,2))\n", | |
" x = x.reshape(b,c,h,w)\n", | |
" return x\n", | |
"\n", | |
"down_kernel = torch.tensor([0.25,0.5,0.25]).reshape(1,1,1,3).to(device)\n", | |
"def smooth_downsample(x, reps=1):\n", | |
" b,c,h,w = x.shape\n", | |
" x = x.reshape(b*c,1,h,w)\n", | |
" for _ in range(reps):\n", | |
" x = TF.pad(x, 1, padding_mode='edge')\n", | |
" x = F.conv2d(x, down_kernel)\n", | |
" x = F.conv2d(x, down_kernel.permute(0,1,3,2))\n", | |
" x = F.avg_pool2d(x, 2)\n", | |
" h,w = x.shape[2:]\n", | |
" x = x.reshape(b,c,h,w)\n", | |
" return x\n", | |
"\n", | |
"def quick_blur_2(x, reps=1):\n", | |
" b,c,h,w = x.shape\n", | |
" x = x.reshape(b*c,1,h,w)\n", | |
" for _ in range(reps):\n", | |
" x = TF.pad(x, 1, padding_mode='edge')\n", | |
" x = F.conv2d(x, down_kernel)\n", | |
" x = F.conv2d(x, down_kernel.permute(0,1,3,2))\n", | |
" x = F.avg_pool2d(x, 2)\n", | |
" for _ in range(reps):\n", | |
" x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)\n", | |
" x = TF.pad(x, 1, padding_mode='edge')\n", | |
" x = F.conv2d(x, down_kernel)\n", | |
" x = F.conv2d(x, down_kernel.permute(0,1,3,2))\n", | |
" h,w = x.shape[2:]\n", | |
" x = x.reshape(b,c,h,w)\n", | |
" return x\n", | |
"\n", | |
"softmax2d = nn.Softmax2d()" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xBGP1D-J-i2p" | |
}, | |
"source": [ | |
"def encode_text_averaging(prompt_in):\n", | |
" x_token_embedding = perceptor.token_embedding(prompt_in) # [batch_size, n_ctx, d_model]\n", | |
"\n", | |
" x = x_token_embedding + perceptor.positional_embedding\n", | |
" x = x.permute(1, 0, 2) # NLD -> LND\n", | |
" x = perceptor.transformer(x.half())\n", | |
" x = x.permute(1, 0, 2) # LND -> NLD\n", | |
" x = x / x.norm(dim=0,keepdim=True)\n", | |
" x = x.mean(0,True)\n", | |
" x = perceptor.ln_final(x)\n", | |
"\n", | |
" # x.shape = [batch_size, n_ctx, transformer.width]\n", | |
" # take features from the eot embedding (eot_token is the highest number in each sequence)\n", | |
" text_argmax = prompt_in.argmax(dim=-1)\n", | |
" x = x[torch.arange(x.shape[0]), text_argmax] @ perceptor.text_projection\n", | |
" x = x / x.norm(dim=0,keepdim=True)\n", | |
" x = x.mean(0,True)\n", | |
" return x\n", | |
"\n", | |
"def clip_encode_image_custom(x_raw, keepgrid=False, averageall=False):\n", | |
" b, c, h, w = x_raw.shape\n", | |
" x_raw = x_raw.unfold(2,2,2).unfold(3,2,2).reshape(b,3,224,224,4).permute(4,0,1,2,3) #.reshape(-1,3,224,224)\n", | |
" x_merge = 0.0\n", | |
" for p in range(4):\n", | |
" x = perceptor.visual.conv1(x_raw[p]) # shape = [*, width, grid, grid]\n", | |
" x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]\n", | |
" x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]\n", | |
" x = torch.cat([perceptor.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]\n", | |
" x = x + perceptor.visual.positional_embedding.to(x.dtype)\n", | |
" x = perceptor.visual.ln_pre(x)\n", | |
"\n", | |
" x = x.permute(1, 0, 2) # NLD -> LND\n", | |
" x = perceptor.visual.transformer(x)\n", | |
" x = x.permute(1, 0, 2) # LND -> NLD\n", | |
"\n", | |
" x_merge = x_merge + x\n", | |
" x = x_merge / 4\n", | |
"\n", | |
" if averageall == True:\n", | |
" x = x.mean(0,True)\n", | |
"\n", | |
" if keepgrid == True:\n", | |
" x = perceptor.visual.ln_post(x)\n", | |
" else:\n", | |
" x = perceptor.visual.ln_post(x[:, 0, :])\n", | |
"\n", | |
" if perceptor.visual.proj is not None:\n", | |
" x = x @ perceptor.visual.proj\n", | |
"\n", | |
" return x\n", | |
"\n", | |
"clip_conv1_backup = perceptor.visual.conv1.weight.data.clone().detach()\n", | |
"clip_conv1_32px = F.interpolate(clip_conv1_backup, (32,32), mode='bicubic', align_corners=False) / 4\n", | |
"\n", | |
"def clip_encode_image_big_kernels(x_raw, keepgrid=False, averageall=False):\n", | |
" x = F.conv2d(x_raw, clip_conv1_32px, stride=32) # shape = [*, width, grid, grid]\n", | |
" x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]\n", | |
" x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]\n", | |
" x = torch.cat([perceptor.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]\n", | |
" x = x + perceptor.visual.positional_embedding.to(x.dtype)\n", | |
" x = perceptor.visual.ln_pre(x)\n", | |
"\n", | |
" x = x.permute(1, 0, 2) # NLD -> LND\n", | |
" x = perceptor.visual.transformer(x)\n", | |
" x = x.permute(1, 0, 2) # LND -> NLD\n", | |
"\n", | |
" if averageall == True:\n", | |
" x = x.mean(0,True)\n", | |
"\n", | |
" if keepgrid == True:\n", | |
" x = perceptor.visual.ln_post(x)\n", | |
" else:\n", | |
" x = perceptor.visual.ln_post(x[:, 0, :])\n", | |
"\n", | |
" if perceptor.visual.proj is not None:\n", | |
" x = x @ perceptor.visual.proj\n", | |
"\n", | |
" return x" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Aarxe29M5DuQ" | |
}, | |
"source": [ | |
"augments = T.Compose([\n", | |
" T.Lambda(lambda x: x + torch.randn_like(x) * TF.rgb_to_grayscale(x) * random.random() * 0.01),\n", | |
" T.RandomChoice([\n", | |
" T.Resize(112, T.InterpolationMode.NEAREST),\n", | |
" T.Resize(224, T.InterpolationMode.NEAREST),\n", | |
" ]),\n", | |
" T.Lambda(lambda x: F.interpolate(x, (random.randint(448,480),random.randint(448,480)), mode='bilinear', align_corners=False)),\n", | |
" T.Lambda(lambda x: x + torch.randn_like(x) * TF.rgb_to_grayscale(x).neg().add(1) * random.random() * 0.01),\n", | |
" T.RandomChoice([\n", | |
" T.Lambda(lambda x: x),\n", | |
" T.Lambda(lambda x: TF.gaussian_blur(x, 3, 1/math.pi)),\n", | |
" T.Lambda(lambda x: TF.gaussian_blur(x, 5, 2/math.pi)),\n", | |
" T.Lambda(lambda x: TF.gaussian_blur(x, 5, 4/math.pi)),\n", | |
" T.Lambda(lambda x: TF.gaussian_blur(x, 7, 6/math.pi)),\n", | |
" ]),\n", | |
" T.Pad(8, padding_mode='edge'),\n", | |
" T.RandomRotation(9, T.InterpolationMode.BILINEAR, True),\n", | |
" T.Lambda(lambda x: x + torch.randn_like(x) * TF.rgb_to_grayscale(x) * random.random() * 0.005),\n", | |
" T.Lambda(lambda x: F.interpolate(x, (random.randint(448,480),random.randint(448,480)), mode='bilinear', align_corners=False)),\n", | |
" T.Lambda(lambda x: x + torch.randn_like(x) * TF.rgb_to_grayscale(x).neg().add(1) * random.random() * 0.005),\n", | |
" T.Lambda(lambda x: F.interpolate(x, scale_factor=(random.random()*0.1+0.5), mode='bilinear', align_corners=False, recompute_scale_factor=False)),\n", | |
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.005),\n", | |
" T.RandomCrop(224),\n", | |
" T.Lambda(lambda x: x + torch.rand(1,3,1,1).div(64).to(device)),\n", | |
" T.Lambda(lambda x: x * (torch.rand(1,3,1,1).to(device)*0.1+0.95)),\n", | |
" T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n", | |
"])" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "F2yvfLyAViSq" | |
}, | |
"source": [ | |
"prompt = perceptor.encode_text(clip.tokenize([\"The Blood Rose, by Zdzislaw Beksinski\"]).to(device)).mean(0,True)\n", | |
"\n", | |
"palette = OpenImage(\"dawnbringers-8-color.png\").reshape(3,8).permute(1,0).to(device)\n", | |
"img_root = F.avg_pool2d(torch.randn(1,palette.shape[0],56,56),8)\n", | |
"img_root = TF.resize(img_root, 56, T.InterpolationMode.NEAREST)\n", | |
"img_root = quick_blur_2(img_root.to(device), 2)\n", | |
"img_root = img_root.to(device).requires_grad_(True)\n", | |
"optimizer = torch.optim.Adam([img_root], lr=1/32)\n", | |
"video_frames = None" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nuyVpwXlalqS" | |
}, | |
"source": [ | |
"torch.set_grad_enabled(True)\n", | |
"\n", | |
"batch_size = 4\n", | |
"steps = 500\n", | |
"soft_clip_gain = 4\n", | |
"\n", | |
"def median2d(x,y,z=1):\n", | |
" x = F.pixel_shuffle(torch.quantile(F.pixel_unshuffle(x.permute(1,0,2,3),y), 0.5, 1, True).tile(1,y*y,1,1),y).permute(1,0,2,3)\n", | |
" if z > 1:\n", | |
" x = quick_blur_2(x,z)\n", | |
" return x\n", | |
"\n", | |
"def apply_pal(x,y,z=10):\n", | |
" x = softmax2d(x * z).permute(0,2,3,1) @ y\n", | |
" x = x.permute(0,3,1,2)\n", | |
" return x\n", | |
"\n", | |
"for i in range(steps):\n", | |
" # with torch.cuda.amp.autocast():\n", | |
" x = apply_pal(img_root, palette, 1) * 0.1 + apply_pal(img_root, palette, 10) * 0.9\n", | |
"\n", | |
" with torch.no_grad():\n", | |
" x_clamp = TF.resize(apply_pal(img_root, palette).clamp(0,1),224,T.InterpolationMode.NEAREST)\n", | |
" if video_frames == None:\n", | |
" video_frames = (x_clamp.permute(0,2,3,1).clamp(0,1)*255).byte().cpu()\n", | |
" else:\n", | |
" video_frames = torch.cat([video_frames, (x_clamp.permute(0,2,3,1).clamp(0,1)*255).byte().cpu()])\n", | |
"\n", | |
" if i < 100:\n", | |
" blur_scale = round(3 * min(max(1-i/100,0),1)**0.5)\n", | |
" # x_blur = F.avg_pool2d(x, blur_scale)\n", | |
" # x_blur = TF.resize(x_blur, 56, T.InterpolationMode.NEAREST)\n", | |
" # x_blur = TF.gaussian_blur(x_blur, 5)\n", | |
" x_blur = quick_blur_2(x, blur_scale)\n", | |
" else:\n", | |
" x_blur = x\n", | |
"\n", | |
" loss = 0.0\n", | |
" for l in range(batch_size):\n", | |
" x_aug = torch.cat([augments(x_blur) for _ in range(4)])\n", | |
" x_enc = perceptor.encode_image(x_aug.half())\n", | |
" loss += torch.cosine_similarity(x_enc, prompt, -1).pow(1).neg().add(1).mean() / batch_size\n", | |
"\n", | |
" if i < 250:\n", | |
" loss += (1-softmax2d(img_root).max(1,True).values).pow(2).mean() / 16\n", | |
" else:\n", | |
" loss += (1-softmax2d(img_root).max(1,True).values).pow(2).mean() / 8\n", | |
"\n", | |
" with torch.no_grad():\n", | |
" loss.backward()\n", | |
" img_root.data = torch.lerp(img_root, median2d(img_root, 8, 2), 0.02)\n", | |
" img_root.data = torch.lerp(img_root, median2d(img_root, 4), 0.03)\n", | |
" img_root.data = torch.lerp(img_root, median2d(img_root, 2), 0.05)\n", | |
" optimizer.step()\n", | |
" optimizer.zero_grad()\n", | |
"\n", | |
" if i % 50 == 0:\n", | |
" print(i, loss.item())\n", | |
" display(ToImage(apply_pal(img_root, palette).clamp(0,1)[0]).resize((224,224),0))\n", | |
"display(ToImage(apply_pal(img_root, palette).clamp(0,1)[0]).resize((224,224),0))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WOyqQm08eVJx" | |
}, | |
"source": [ | |
"torchvision.io.write_video(\"bloodrose.mp4\", video_frames, 15, options={'crf': '28'})" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |