Skip to content

Instantly share code, notes, and snippets.

Created February 7, 2022 18:13
Show Gist options
  • Save manzke/2a27deb2bd1dca1dbf3187284ee69d7b to your computer and use it in GitHub Desktop.
Save manzke/2a27deb2bd1dca1dbf3187284ee69d7b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Image inpainting",
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"background_execution": "on"
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
"accelerator": "GPU",
"language_info": {
"name": "python"
"cells": [
"cell_type": "code",
"source": [
"%cd /content\n",
"!git clone"
"metadata": {
"id": "BMf9MDSCwfxp"
"execution_count": null,
"outputs": []
"cell_type": "code",
"source": [
"import torch\n",
"from torch import nn\n",
"#from .modules import Conv2dBlock, Concat\n",
"class SkipEncoderDecoder(nn.Module):\n",
" def __init__(self, input_depth, num_channels_down = [128] * 5, num_channels_up = [128] * 5, num_channels_skip = [128] * 5):\n",
" super(SkipEncoderDecoder, self).__init__()\n",
" self.model = nn.Sequential()\n",
" model_tmp = self.model\n",
" for i in range(len(num_channels_down)):\n",
" deeper = nn.Sequential()\n",
" skip = nn.Sequential()\n",
" if num_channels_skip[i] != 0:\n",
" model_tmp.add_module(str(len(model_tmp) + 1), Concat(1, skip, deeper))\n",
" else:\n",
" model_tmp.add_module(str(len(model_tmp) + 1), deeper)\n",
" \n",
" model_tmp.add_module(str(len(model_tmp) + 1), nn.BatchNorm2d(num_channels_skip[i] + (num_channels_up[i + 1] if i < (len(num_channels_down) - 1) else num_channels_down[i])))\n",
" if num_channels_skip[i] != 0:\n",
" skip.add_module(str(len(skip) + 1), Conv2dBlock(input_depth, num_channels_skip[i], 1, bias = False))\n",
" \n",
" deeper.add_module(str(len(deeper) + 1), Conv2dBlock(input_depth, num_channels_down[i], 3, 2, bias = False))\n",
" deeper.add_module(str(len(deeper) + 1), Conv2dBlock(num_channels_down[i], num_channels_down[i], 3, bias = False))\n",
" deeper_main = nn.Sequential()\n",
" if i == len(num_channels_down) - 1:\n",
" k = num_channels_down[i]\n",
" else:\n",
" deeper.add_module(str(len(deeper) + 1), deeper_main)\n",
" k = num_channels_up[i + 1]\n",
" deeper.add_module(str(len(deeper) + 1), nn.Upsample(scale_factor = 2, mode = 'nearest'))\n",
" model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_skip[i] + k, num_channels_up[i], 3, 1, bias = False))\n",
" model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_up[i], num_channels_up[i], 1, bias = False))\n",
" input_depth = num_channels_down[i]\n",
" model_tmp = deeper_main\n",
" self.model.add_module(str(len(self.model) + 1), nn.Conv2d(num_channels_up[0], 3, 1, bias = True))\n",
" self.model.add_module(str(len(self.model) + 1), nn.Sigmoid())\n",
" \n",
" def forward(self, x):\n",
" return self.model(x)\n",
"def input_noise(INPUT_DEPTH, spatial_size, scale = 1./10):\n",
" shape = [1, INPUT_DEPTH, spatial_size[0], spatial_size[1]]\n",
" return torch.rand(*shape) * scale"
"metadata": {
"id": "qLJ-Px9j5572"
"execution_count": null,
"outputs": []
"cell_type": "code",
"source": [
"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"class DepthwiseSeperableConv2d(nn.Module):\n",
" def __init__(self, input_channels, output_channels, **kwargs):\n",
" super(DepthwiseSeperableConv2d, self).__init__()\n",
" self.depthwise = nn.Conv2d(input_channels, input_channels, groups = input_channels, **kwargs)\n",
" self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size = 1)\n",
" def forward(self, x):\n",
" x = self.depthwise(x)\n",
" x = self.pointwise(x)\n",
" return x\n",
"class Conv2dBlock(nn.Module):\n",
" def __init__(self, in_channels, out_channels, kernel_size, stride = 1, bias = False):\n",
" super(Conv2dBlock, self).__init__()\n",
" self.model = nn.Sequential(\n",
" nn.ReflectionPad2d(int((kernel_size - 1) / 2)),\n",
" DepthwiseSeperableConv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = 0, bias = bias),\n",
" nn.BatchNorm2d(out_channels),\n",
" nn.LeakyReLU(0.2)\n",
" )\n",
" def forward(self, x):\n",
" return self.model(x)\n",
"class Concat(nn.Module):\n",
" def __init__(self, dim, *args):\n",
" super(Concat, self).__init__()\n",
" self.dim = dim\n",
" for idx, module in enumerate(args):\n",
" self.add_module(str(idx), module)\n",
" def forward(self, input):\n",
" inputs = []\n",
" for module in self._modules.values():\n",
" inputs.append(module(input))\n",
" inputs_shapes2 = [x.shape[2] for x in inputs]\n",
" inputs_shapes3 = [x.shape[3] for x in inputs] \n",
" if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):\n",
" inputs_ = inputs\n",
" else:\n",
" target_shape2 = min(inputs_shapes2)\n",
" target_shape3 = min(inputs_shapes3)\n",
" inputs_ = []\n",
" for inp in inputs: \n",
" diff2 = (inp.size(2) - target_shape2) // 2 \n",
" diff3 = (inp.size(3) - target_shape3) // 2 \n",
" inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3])\n",
 return torch.cat(inputs_, dim=self.dim)\n",
" def __len__(self):\n",
" return len(self._modules)"
"metadata": {
"id": "YWUJ4JLO5_Uy"
"execution_count": null,
"outputs": []
"cell_type": "code",
"source": [
"from torch import optim\n",
"from tqdm.auto import tqdm\n",
"from helper import *\n",
"#from model.generator import SkipEncoderDecoder, input_noise\n",
"def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, show_step, training_steps, tqdm_length = 100):\n",
" DTYPE = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
" if not torch.cuda.is_available():\n",
" print('\\nSetting device to \"cpu\", since torch is not built with \"cuda\" support...')\n",
" print('It is recommended to use GPU if possible...')\n",
" image_np, mask_np = preprocess_images(image_path, mask_path, max_dim)\n",
" print('Building the model...')\n",
" generator = SkipEncoderDecoder(\n",
" input_depth,\n",
" num_channels_down = [128] * 5,\n",
" num_channels_up = [128] * 5,\n",
" num_channels_skip = [128] * 5\n",
" ).type(DTYPE)\n",
" objective = torch.nn.MSELoss().type(DTYPE)\n",
" optimizer = optim.Adam(generator.parameters(), lr)\n",
" image_var = np_to_torch_array(image_np).type(DTYPE)\n",
" mask_var = np_to_torch_array(mask_np).type(DTYPE)\n",
" generator_input = input_noise(input_depth, image_np.shape[1:]).type(DTYPE)\n",
" generator_input_saved = generator_input.detach().clone()\n",
" noise = generator_input.detach().clone()\n",
" print('\\nStarting training...\\n')\n",
" progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = tqdm_length)\n",
" for step in progress_bar:\n",
" optimizer.zero_grad()\n",
" generator_input = generator_input_saved\n",
" if reg_noise > 0:\n",
" generator_input = generator_input_saved + (noise.normal_() * reg_noise)\n",
" \n",
" output = generator(generator_input)\n",
" \n",
" loss = objective(output * mask_var, image_var * mask_var)\n",
" loss.backward()\n",
" if step % show_step == 0:\n",
" output_image = torch_to_np_array(output)\n",
" visualize_sample(image_np, output_image, nrow = 2, size_factor = 10)\n",
" \n",
" progress_bar.set_postfix(Loss = loss.item())\n",
" \n",
" optimizer.step()\n",
" \n",
" output_image = torch_to_np_array(output)\n",
" visualize_sample(output_image, nrow = 1, size_factor = 10)\n",
" pil_image = Image.fromarray((output_image.transpose(1, 2, 0) * 255.0).astype('uint8'))\n",
" output_path = image_path.split('/')[-1].split('.')[-2] + '-output.jpg'\n",
 print(f'\\nSaving final output image to: \"{output_path}\"\\n')\n",
 pil_image.save(output_path)"
"metadata": {
"id": "G5-r4OA96EaQ"
"execution_count": null,
"outputs": []
"cell_type": "code",
"source": [
"import numpy as np\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from torchvision.utils import make_grid\n",
"def pil_to_np_array(pil_image):\n",
" ar = np.array(pil_image)\n",
" if len(ar.shape) == 3:\n",
" ar = ar.transpose(2,0,1)\n",
" else:\n",
" ar = ar[None, ...]\n",
" return ar.astype(np.float32) / 255.\n",
"def np_to_torch_array(np_array):\n",
" return torch.from_numpy(np_array)[None, :]\n",
"def torch_to_np_array(torch_array):\n",
" return torch_array.detach().cpu().numpy()[0]\n",
"def read_image(path, image_size = -1):\n",
 pil_image = Image.open(path)\n",
" return pil_image\n",
"def crop_image(image, crop_factor = 64):\n",
" shape = (image.size[0] - image.size[0] % crop_factor, image.size[1] - image.size[1] % crop_factor)\n",
" bbox = [int((image.shape[0] - shape[0])/2), int((image.shape[1] - shape[1])/2), int((image.shape[0] + shape[0])/2), int((image.shape[1] + shape[1])/2)]\n",
" return image.crop(bbox)\n",
"def get_image_grid(images, nrow = 3):\n",
" torch_images = [torch.from_numpy(x) for x in images]\n",
" grid = make_grid(torch_images, nrow)\n",
" return grid.numpy()\n",
" \n",
"def visualize_sample(*images_np, nrow = 3, size_factor = 10):\n",
" c = max(x.shape[0] for x in images_np)\n",
" images_np = [x if (x.shape[0] == c) else np.concatenate([x, x, x], axis = 0) for x in images_np]\n",
" grid = get_image_grid(images_np, nrow)\n",
" plt.figure(figsize = (len(images_np) + size_factor, 12 + size_factor))\n",
" plt.axis('off')\n",
" plt.imshow(grid.transpose(1, 2, 0))\n",
"def max_dimension_resize(image_pil, mask_pil, max_dim):\n",
" w, h = image_pil.size\n",
" aspect_ratio = w / h\n",
" if w > max_dim:\n",
" h = int((h / w) * max_dim)\n",
" w = max_dim\n",
" elif h > max_dim:\n",
" w = int((w / h) * max_dim)\n",
" h = max_dim\n",
" return image_pil.resize((w, h)), mask_pil.resize((w, h))\n",
"def preprocess_images(image_path, mask_path, max_dim):\n",
" image_pil = read_image(image_path).convert('RGB')\n",
" mask_pil = read_image(mask_path).convert('RGB')\n",
" image_pil, mask_pil = max_dimension_resize(image_pil, mask_pil, max_dim)\n",
" image_np = pil_to_np_array(image_pil)\n",
" mask_np = pil_to_np_array(mask_pil)\n",
" print('Visualizing mask overlap...')\n",
" visualize_sample(image_np, mask_np, image_np * mask_np, nrow = 3, size_factor = 10)\n",
" return image_np, mask_np"
"metadata": {
"id": "XGEQsXnc6ICq"
"execution_count": null,
"outputs": []
"cell_type": "code",
"source": [
"import argparse\n",
"import os\n",
"#from api import remove_watermark\n",
"DTYPE = torch.cuda.FloatTensor\n",
"image_path = os.path.join('image.jpg')\n",
"mask_path = os.path.join('watermark-shutterstock.png')\n",
"max_dim = 1024\n",
"show_step = 50\n",
"reg_noise = 0.03\n",
"input_depth = 256\n",
"lr = 0.01\n",
"training_steps = 5000\n",
"tqdm_length = 900\n",
"metadata": {
"id": "bMI9C9CV6PUl"
"execution_count": null,
"outputs": []
"cell_type": "code",
"source": [
" remove_watermark(\n",
" image_path = image_path,\n",
" mask_path = mask_path,\n",
" max_dim = max_dim,\n",
" show_step = show_step,\n",
" reg_noise = reg_noise,\n",
" input_depth = input_depth,\n",
" lr = lr,\n",
" training_steps = training_steps,\n",
 tqdm_length = tqdm_length\n",
 )"
"metadata": {
"id": "HWY9lfT-6OMJ"
"execution_count": null,
"outputs": []
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment