{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Image inpainting",
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"background_execution": "on"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU",
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
"%cd /content\n",
"!git clone https://github.com/braindotai/Watermark-Removal-Pytorch.git"
],
"metadata": {
"id": "BMf9MDSCwfxp"
},
"execution_count": null,
"outputs": []
},
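{
"cell_type": "code",
"source": [
"# optional environment check (added, not part of the original gist): the notebook metadata\n",
"# requests a GPU runtime, and remove_watermark below only warns before falling back to CPU,\n",
"# so it is worth confirming that CUDA is actually visible before training.\n",
"import torch\n",
"print('torch version:', torch.__version__)\n",
"print('CUDA available:', torch.cuda.is_available())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},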
{
"cell_type": "code",
"source": [
"#model\n",
"# note: Conv2dBlock and Concat are defined in the next cell; run all definition cells before training\n",
"import torch\n",
"from torch import nn\n",
"#from .modules import Conv2dBlock, Concat\n",
"\n",
"class SkipEncoderDecoder(nn.Module):\n",
"    def __init__(self, input_depth, num_channels_down = [128] * 5, num_channels_up = [128] * 5, num_channels_skip = [128] * 5):\n",
"        super(SkipEncoderDecoder, self).__init__()\n",
"\n",
"        self.model = nn.Sequential()\n",
"        model_tmp = self.model\n",
"\n",
"        for i in range(len(num_channels_down)):\n",
"\n",
"            deeper = nn.Sequential()\n",
"            skip = nn.Sequential()\n",
"\n",
"            if num_channels_skip[i] != 0:\n",
"                model_tmp.add_module(str(len(model_tmp) + 1), Concat(1, skip, deeper))\n",
"            else:\n",
"                model_tmp.add_module(str(len(model_tmp) + 1), deeper)\n",
"\n",
"            model_tmp.add_module(str(len(model_tmp) + 1), nn.BatchNorm2d(num_channels_skip[i] + (num_channels_up[i + 1] if i < (len(num_channels_down) - 1) else num_channels_down[i])))\n",
"\n",
"            if num_channels_skip[i] != 0:\n",
"                skip.add_module(str(len(skip) + 1), Conv2dBlock(input_depth, num_channels_skip[i], 1, bias = False))\n",
"\n",
"            deeper.add_module(str(len(deeper) + 1), Conv2dBlock(input_depth, num_channels_down[i], 3, 2, bias = False))\n",
"            deeper.add_module(str(len(deeper) + 1), Conv2dBlock(num_channels_down[i], num_channels_down[i], 3, bias = False))\n",
"\n",
"            deeper_main = nn.Sequential()\n",
"\n",
"            if i == len(num_channels_down) - 1:\n",
"                k = num_channels_down[i]\n",
"            else:\n",
"                deeper.add_module(str(len(deeper) + 1), deeper_main)\n",
"                k = num_channels_up[i + 1]\n",
"\n",
"            deeper.add_module(str(len(deeper) + 1), nn.Upsample(scale_factor = 2, mode = 'nearest'))\n",
"\n",
"            model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_skip[i] + k, num_channels_up[i], 3, 1, bias = False))\n",
"            model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_up[i], num_channels_up[i], 1, bias = False))\n",
"\n",
"            input_depth = num_channels_down[i]\n",
"            model_tmp = deeper_main\n",
"\n",
"        self.model.add_module(str(len(self.model) + 1), nn.Conv2d(num_channels_up[0], 3, 1, bias = True))\n",
"        self.model.add_module(str(len(self.model) + 1), nn.Sigmoid())\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"\n",
"def input_noise(INPUT_DEPTH, spatial_size, scale = 1./10):\n",
"    shape = [1, INPUT_DEPTH, spatial_size[0], spatial_size[1]]\n",
"    return torch.rand(*shape) * scale"
],
"metadata": {
"id": "qLJ-Px9j5572"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#generator\n",
"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"\n",
"class DepthwiseSeperableConv2d(nn.Module):\n",
"    def __init__(self, input_channels, output_channels, **kwargs):\n",
"        super(DepthwiseSeperableConv2d, self).__init__()\n",
"\n",
"        self.depthwise = nn.Conv2d(input_channels, input_channels, groups = input_channels, **kwargs)\n",
"        self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size = 1)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.depthwise(x)\n",
"        x = self.pointwise(x)\n",
"\n",
"        return x\n",
"\n",
"class Conv2dBlock(nn.Module):\n",
"    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, bias = False):\n",
"        super(Conv2dBlock, self).__init__()\n",
"\n",
"        self.model = nn.Sequential(\n",
"            nn.ReflectionPad2d(int((kernel_size - 1) / 2)),\n",
"            DepthwiseSeperableConv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = 0, bias = bias),\n",
"            nn.BatchNorm2d(out_channels),\n",
"            nn.LeakyReLU(0.2)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"class Concat(nn.Module):\n",
"    def __init__(self, dim, *args):\n",
"        super(Concat, self).__init__()\n",
"        self.dim = dim\n",
"\n",
"        for idx, module in enumerate(args):\n",
"            self.add_module(str(idx), module)\n",
"\n",
"    def forward(self, input):\n",
"        inputs = []\n",
"        for module in self._modules.values():\n",
"            inputs.append(module(input))\n",
"\n",
"        inputs_shapes2 = [x.shape[2] for x in inputs]\n",
"        inputs_shapes3 = [x.shape[3] for x in inputs]\n",
"\n",
"        if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):\n",
"            inputs_ = inputs\n",
"        else:\n",
"            target_shape2 = min(inputs_shapes2)\n",
"            target_shape3 = min(inputs_shapes3)\n",
"\n",
"            inputs_ = []\n",
"            for inp in inputs:\n",
"                diff2 = (inp.size(2) - target_shape2) // 2\n",
"                diff3 = (inp.size(3) - target_shape3) // 2\n",
"                inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3])\n",
"\n",
"        return torch.cat(inputs_, dim=self.dim)\n",
"\n",
"    def __len__(self):\n",
"        return len(self._modules)"
],
"metadata": {
"id": "YWUJ4JLO5_Uy"
},
"execution_count": null,
"outputs": []
},
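{
"cell_type": "code",
"source": [
"# optional shape check (added, not part of the original gist): build a small\n",
"# SkipEncoderDecoder and push input_noise through it to confirm the output is a\n",
"# 3-channel image with the same spatial size as the input noise.\n",
"test_input = input_noise(8, (64, 64))\n",
"test_model = SkipEncoderDecoder(8, [16] * 5, [16] * 5, [4] * 5)\n",
"with torch.no_grad():\n",
"    test_output = test_model(test_input)\n",
"print(test_input.shape, '->', test_output.shape)  # expect (1, 8, 64, 64) -> (1, 3, 64, 64)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},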
{
"cell_type": "code",
"source": [
"#api\n",
"from torch import optim\n",
"from tqdm.auto import tqdm\n",
"#from helper import *  # the helper functions are defined inline in a later cell of this notebook\n",
"#from model.generator import SkipEncoderDecoder, input_noise\n",
"\n",
"def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, show_step, training_steps, tqdm_length = 100):\n",
"    DTYPE = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
"    if not torch.cuda.is_available():\n",
"        print('\\nSetting device to \"cpu\", since torch is not built with \"cuda\" support...')\n",
"        print('It is recommended to use GPU if possible...')\n",
"\n",
"    image_np, mask_np = preprocess_images(image_path, mask_path, max_dim)\n",
"\n",
"    print('Building the model...')\n",
"    generator = SkipEncoderDecoder(\n",
"        input_depth,\n",
"        num_channels_down = [128] * 5,\n",
"        num_channels_up = [128] * 5,\n",
"        num_channels_skip = [128] * 5\n",
"    ).type(DTYPE)\n",
"\n",
"    objective = torch.nn.MSELoss().type(DTYPE)\n",
"    optimizer = optim.Adam(generator.parameters(), lr)\n",
"\n",
"    image_var = np_to_torch_array(image_np).type(DTYPE)\n",
"    mask_var = np_to_torch_array(mask_np).type(DTYPE)\n",
"\n",
"    generator_input = input_noise(input_depth, image_np.shape[1:]).type(DTYPE)\n",
"\n",
"    generator_input_saved = generator_input.detach().clone()\n",
"    noise = generator_input.detach().clone()\n",
"\n",
"    print('\\nStarting training...\\n')\n",
"\n",
"    progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = tqdm_length)\n",
"\n",
"    for step in progress_bar:\n",
"        optimizer.zero_grad()\n",
"        generator_input = generator_input_saved\n",
"\n",
"        if reg_noise > 0:\n",
"            generator_input = generator_input_saved + (noise.normal_() * reg_noise)\n",
"\n",
"        output = generator(generator_input)\n",
"\n",
"        loss = objective(output * mask_var, image_var * mask_var)\n",
"        loss.backward()\n",
"\n",
"        if step % show_step == 0:\n",
"            output_image = torch_to_np_array(output)\n",
"            visualize_sample(image_np, output_image, nrow = 2, size_factor = 10)\n",
"\n",
"        progress_bar.set_postfix(Loss = loss.item())\n",
"\n",
"        optimizer.step()\n",
"\n",
"    output_image = torch_to_np_array(output)\n",
"    visualize_sample(output_image, nrow = 1, size_factor = 10)\n",
"\n",
"    pil_image = Image.fromarray((output_image.transpose(1, 2, 0) * 255.0).astype('uint8'))\n",
"\n",
"    output_path = image_path.split('/')[-1].split('.')[-2] + '-output.jpg'\n",
"    print(f'\\nSaving final output image to: \"{output_path}\"\\n')\n",
"\n",
"    pil_image.save(output_path)"
],
"metadata": {
"id": "G5-r4OA96EaQ"
},
"execution_count": null,
"outputs": []
},
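{
"cell_type": "code",
"source": [
"# illustration (added, not part of the original gist) of the masked objective used above:\n",
"# the MSE is taken on output * mask vs image * mask, so pixels where the mask is 0\n",
"# contribute nothing to the loss -- the mask should be black (0) over the watermark\n",
"# and white (1) everywhere else, leaving the generator free to fill in the masked region.\n",
"dummy_image = torch.ones(1, 3, 4, 4)\n",
"dummy_output = torch.zeros(1, 3, 4, 4)\n",
"dummy_mask = torch.ones(1, 3, 4, 4)\n",
"dummy_mask[..., :2, :2] = 0  # pretend the top-left 2x2 block is covered by the watermark\n",
"mse = torch.nn.MSELoss()\n",
"print(mse(dummy_output * dummy_mask, dummy_image * dummy_mask))  # 36 unmasked elements differ by 1 -> 36/48 = 0.75"
],
"metadata": {},
"execution_count": null,
"outputs": []
},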
{
"cell_type": "code",
"source": [
"#helper\n",
"import numpy as np\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"from torchvision.utils import make_grid\n",
"\n",
"def pil_to_np_array(pil_image):\n",
"    ar = np.array(pil_image)\n",
"    if len(ar.shape) == 3:\n",
"        ar = ar.transpose(2,0,1)\n",
"    else:\n",
"        ar = ar[None, ...]\n",
"    return ar.astype(np.float32) / 255.\n",
"\n",
"def np_to_torch_array(np_array):\n",
"    return torch.from_numpy(np_array)[None, :]\n",
"\n",
"def torch_to_np_array(torch_array):\n",
"    return torch_array.detach().cpu().numpy()[0]\n",
"\n",
"def read_image(path, image_size = -1):\n",
"    pil_image = Image.open(path)\n",
"    return pil_image\n",
"\n",
"def crop_image(image, crop_factor = 64):\n",
"    shape = (image.size[0] - image.size[0] % crop_factor, image.size[1] - image.size[1] % crop_factor)\n",
"    bbox = [int((image.size[0] - shape[0])/2), int((image.size[1] - shape[1])/2), int((image.size[0] + shape[0])/2), int((image.size[1] + shape[1])/2)]\n",
"    return image.crop(bbox)\n",
"\n",
"def get_image_grid(images, nrow = 3):\n",
"    torch_images = [torch.from_numpy(x) for x in images]\n",
"    grid = make_grid(torch_images, nrow)\n",
"    return grid.numpy()\n",
"\n",
"def visualize_sample(*images_np, nrow = 3, size_factor = 10):\n",
"    c = max(x.shape[0] for x in images_np)\n",
"    images_np = [x if (x.shape[0] == c) else np.concatenate([x, x, x], axis = 0) for x in images_np]\n",
"    grid = get_image_grid(images_np, nrow)\n",
"    plt.figure(figsize = (len(images_np) + size_factor, 12 + size_factor))\n",
"    plt.axis('off')\n",
"    plt.imshow(grid.transpose(1, 2, 0))\n",
"    plt.show()\n",
"\n",
"def max_dimension_resize(image_pil, mask_pil, max_dim):\n",
"    w, h = image_pil.size\n",
"    aspect_ratio = w / h\n",
"    if w > max_dim:\n",
"        h = int((h / w) * max_dim)\n",
"        w = max_dim\n",
"    elif h > max_dim:\n",
"        w = int((w / h) * max_dim)\n",
"        h = max_dim\n",
"    return image_pil.resize((w, h)), mask_pil.resize((w, h))\n",
"\n",
"def preprocess_images(image_path, mask_path, max_dim):\n",
"    image_pil = read_image(image_path).convert('RGB')\n",
"    mask_pil = read_image(mask_path).convert('RGB')\n",
"\n",
"    image_pil, mask_pil = max_dimension_resize(image_pil, mask_pil, max_dim)\n",
"\n",
"    image_np = pil_to_np_array(image_pil)\n",
"    mask_np = pil_to_np_array(mask_pil)\n",
"\n",
"    print('Visualizing mask overlap...')\n",
"\n",
"    visualize_sample(image_np, mask_np, image_np * mask_np, nrow = 3, size_factor = 10)\n",
"\n",
"    return image_np, mask_np"
],
"metadata": {
"id": "XGEQsXnc6ICq"
},
"execution_count": null,
"outputs": []
},
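{
"cell_type": "code",
"source": [
"# optional smoke-test inputs (added, not part of the original gist): the config cell below\n",
"# expects 'image.jpg' and 'watermark-shutterstock.png' to already be uploaded to /content.\n",
"# If they are not, this creates a small synthetic image and a matching mask\n",
"# (white = pixels to trust, black = region to reconstruct) so the pipeline can be exercised.\n",
"import os\n",
"import numpy as np\n",
"from PIL import Image, ImageDraw\n",
"\n",
"if not (os.path.exists('image.jpg') and os.path.exists('watermark-shutterstock.png')):\n",
"    ramp = np.tile(np.arange(256, dtype = 'uint8'), (256, 1))\n",
"    Image.fromarray(np.stack([ramp, ramp.T, np.full_like(ramp, 128)], axis = -1)).save('image.jpg')\n",
"    mask = Image.new('RGB', (256, 256), 'white')\n",
"    ImageDraw.Draw(mask).rectangle([96, 96, 160, 160], fill = 'black')\n",
"    mask.save('watermark-shutterstock.png')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},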
{
"cell_type": "code",
"source": [
"#config\n",
"#import argparse  # leftover from the repo's CLI script; unused in this notebook\n",
"import os\n",
"#from api import remove_watermark\n",
"\n",
"#DTYPE = torch.cuda.FloatTensor  # unused: remove_watermark selects the tensor type itself, with a CPU fallback\n",
"\n",
"image_path = os.path.join('image.jpg')\n",
"mask_path = os.path.join('watermark-shutterstock.png')\n",
"max_dim = 1024\n",
"show_step = 50\n",
"reg_noise = 0.03\n",
"input_depth = 256\n",
"lr = 0.01\n",
"training_steps = 5000\n",
"tqdm_length = 900\n",
"\n"
],
"metadata": {
"id": "bMI9C9CV6PUl"
},
"execution_count": null,
"outputs": []
},
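{
"cell_type": "code",
"source": [
"# optional pre-flight check (added, not part of the original gist): confirm both input\n",
"# files are present and report their resolution before starting the long optimization run.\n",
"for path in (image_path, mask_path):\n",
"    assert os.path.exists(path), f'missing input file: {path}'\n",
"    print(path, Image.open(path).size)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},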
{
"cell_type": "code",
"source": [
"remove_watermark(\n",
"    image_path = image_path,\n",
"    mask_path = mask_path,\n",
"    max_dim = max_dim,\n",
"    show_step = show_step,\n",
"    reg_noise = reg_noise,\n",
"    input_depth = input_depth,\n",
"    lr = lr,\n",
"    training_steps = training_steps,\n",
"    tqdm_length = tqdm_length\n",
")"
],
"metadata": {
"id": "HWY9lfT-6OMJ"
},
"execution_count": null,
"outputs": []
}
]
} |