Created
February 7, 2022 18:34
-
-
Save manzke/d443be4dcf3f78a32b81dd0eb913b234 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "_pRpIwnaOnb3" | |
}, | |
"source": [ | |
"# 🦙 **LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions**\n", | |
"\n", | |
"[[Project page](https://saic-mdal.github.io/lama-project/)] [[GitHub](https://github.com/saic-mdal/lama)] [[arXiv](https://arxiv.org/abs/2109.07161)] [[Supplementary](https://ashukha.com/projects/lama_21/lama_supmat_2021.pdf)] [[BibTeX](https://senya-ashukha.github.io/projects/lama_21/paper.txt)]\n", | |
"\n", | |
"<p align=\"center\" \"font-size:30px;\">\n", | |
"Our model generalizes surprisingly well to much higher resolutions (~2k❗️) than it saw during training (256x256), and achieves the excellent performance even in challenging scenarios, e.g. completion of periodic structures.\n", | |
"</p>\n", | |
"\n", | |
"# Try it yourself!👇\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "form", | |
"id": "RwXRMaNHW4r5" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@title Run this sell to set everything up\n", | |
"print('\\n> Cloning the repo')\n", | |
"!git clone https://github.com/saic-mdal/lama.git\n", | |
"\n", | |
"print('\\n> Install dependencies')\n", | |
"!pip install -r lama/requirements.txt --quiet\n", | |
"!pip install wget --quiet\n", | |
"\n", | |
"print('\\n> Changing the dir to:')\n", | |
"% cd /content/lama\n", | |
"\n", | |
"print('\\n> Download the model')\n", | |
"!curl -L $(yadisk-direct https://disk.yandex.ru/d/ouP6l8VJ0HpMZg) -o big-lama.zip\n", | |
"!unzip big-lama.zip\n", | |
"\n", | |
"print('>fixing opencv')\n", | |
"!pip uninstall opencv-python-headless -y --quiet\n", | |
"!pip install opencv-python-headless==4.1.2.30 --quiet\n", | |
"\n", | |
"\n", | |
"print('\\n> Init mask-drawing code')\n", | |
"import base64, os\n", | |
"from IPython.display import HTML, Image\n", | |
"from google.colab.output import eval_js\n", | |
"from base64 import b64decode\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import wget\n", | |
"from shutil import copyfile\n", | |
"import shutil\n", | |
"\n", | |
"\n", | |
"\n", | |
"canvas_html = \"\"\"\n", | |
"<style>\n", | |
".button {\n", | |
" background-color: #4CAF50;\n", | |
" border: none;\n", | |
" color: white;\n", | |
" padding: 15px 32px;\n", | |
" text-align: center;\n", | |
" text-decoration: none;\n", | |
" display: inline-block;\n", | |
" font-size: 16px;\n", | |
" margin: 4px 2px;\n", | |
" cursor: pointer;\n", | |
"}\n", | |
"</style>\n", | |
"<canvas1 width=%d height=%d>\n", | |
"</canvas1>\n", | |
"<canvas width=%d height=%d>\n", | |
"</canvas>\n", | |
"\n", | |
"<button class=\"button\">Finish</button>\n", | |
"<script>\n", | |
"var canvas = document.querySelector('canvas')\n", | |
"var ctx = canvas.getContext('2d')\n", | |
"\n", | |
"var canvas1 = document.querySelector('canvas1')\n", | |
"var ctx1 = canvas.getContext('2d')\n", | |
"\n", | |
"\n", | |
"ctx.strokeStyle = 'red';\n", | |
"\n", | |
"var img = new Image();\n", | |
"img.src = \"data:image/%s;charset=utf-8;base64,%s\";\n", | |
"console.log(img)\n", | |
"img.onload = function() {\n", | |
" ctx1.drawImage(img, 0, 0);\n", | |
"};\n", | |
"img.crossOrigin = 'Anonymous';\n", | |
"\n", | |
"ctx.clearRect(0, 0, canvas.width, canvas.height);\n", | |
"\n", | |
"ctx.lineWidth = %d\n", | |
"var button = document.querySelector('button')\n", | |
"var mouse = {x: 0, y: 0}\n", | |
"\n", | |
"canvas.addEventListener('mousemove', function(e) {\n", | |
" mouse.x = e.pageX - this.offsetLeft\n", | |
" mouse.y = e.pageY - this.offsetTop\n", | |
"})\n", | |
"canvas.onmousedown = ()=>{\n", | |
" ctx.beginPath()\n", | |
" ctx.moveTo(mouse.x, mouse.y)\n", | |
" canvas.addEventListener('mousemove', onPaint)\n", | |
"}\n", | |
"canvas.onmouseup = ()=>{\n", | |
" canvas.removeEventListener('mousemove', onPaint)\n", | |
"}\n", | |
"var onPaint = ()=>{\n", | |
" ctx.lineTo(mouse.x, mouse.y)\n", | |
" ctx.stroke()\n", | |
"}\n", | |
"\n", | |
"var data = new Promise(resolve=>{\n", | |
" button.onclick = ()=>{\n", | |
" resolve(canvas.toDataURL('image/png'))\n", | |
" }\n", | |
"})\n", | |
"</script>\n", | |
"\"\"\"\n", | |
"\n", | |
"def draw(imgm, filename='drawing.png', w=400, h=200, line_width=1):\n", | |
" display(HTML(canvas_html % (w, h, w,h, filename.split('.')[-1], imgm, line_width)))\n", | |
" data = eval_js(\"data\")\n", | |
" binary = b64decode(data.split(',')[1])\n", | |
" with open(filename, 'wb') as f:\n", | |
" f.write(binary)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "23WaUHiJeyBO" | |
}, | |
"source": [ | |
"<center>\n", | |
"<h1 style=\"font-size:10vw\"><b>Predefined photo</b>: uncomment any line\n", | |
"<br>\n", | |
"<b>Local file</b>: leave the <tt>fname = None</tt></h1>\n", | |
"</center>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "IFIDDD4IhPXd" | |
}, | |
"outputs": [], | |
"source": [ | |
"fname = None\n", | |
"#fname = '0ca8e4f3f4150937082add7dbdd149196dc05ec2.jpg'\n", | |
"# fname = 'https://ic.pics.livejournal.com/mostovoy/28566193/1224276/1224276_original.jpg' # <-in the example\n", | |
"# fname = 'https://raw.githubusercontent.com/senya-ashukha/senya-ashukha.github.io/master/images/1010286.jpeg'\n", | |
"# fname = 'https://raw.githubusercontent.com/senya-ashukha/senya-ashukha.github.io/master/images/1010287.jpeg'\n", | |
"# fname = \"https://raw.githubusercontent.com/senya-ashukha/senya-ashukha.github.io/master/images/alex.jpg\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "code", | |
"id": "-VZWySTMeGDM" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@title Draw a Mask, Press Finish, Wait for Inpainting\n", | |
"\n", | |
"#if fname is None:\n", | |
"# from google.colab import files\n", | |
"# files = files.upload()\n", | |
"# fname = list(files.keys())[0]\n", | |
"#else:\n", | |
"# fname = wget.download(fname)\n", | |
"\n", | |
"#shutil.rmtree('./data_for_prediction', ignore_errors=True)\n", | |
"#! mkdir data_for_prediction\n", | |
"\n", | |
"#copyfile(fname, f'./data_for_prediction/{fname}')\n", | |
"#os.remove(fname)\n", | |
"fname = f'./data_for_prediction/{fname}'\n", | |
"\n", | |
"image64 = base64.b64encode(open(fname, 'rb').read())\n", | |
"image64 = image64.decode('utf-8')\n", | |
"\n", | |
"print(f'Will use {fname} for inpainting')\n", | |
"img = np.array(plt.imread(f'{fname}')[:,:,:3])\n", | |
"\n", | |
"#draw(image64, filename=f\"./{fname.split('.')[1]}_mask.png\", w=img.shape[1], h=img.shape[0], line_width=0.04*img.shape[1])\n", | |
"#@title Show a masked image and save a mask\n", | |
"import matplotlib.pyplot as plt\n", | |
"plt.rcParams[\"figure.figsize\"] = (15,5)\n", | |
"plt.rcParams['figure.dpi'] = 200\n", | |
"plt.subplot(131)\n", | |
"with_mask = np.array(plt.imread(f\"./{fname.split('.')[1]}_mask.png\")[:,:,:3])\n", | |
"plt.imsave(f\"./{fname.split('.')[1]}_mask_gray.png\", with_mask, cmap='gray')\n", | |
"with_mask = np.array(plt.imread(f\"./{fname.split('.')[1]}_mask_gray.png\")[:,:,:3])\n", | |
"mask = (with_mask[:,:,0]==1)*(with_mask[:,:,1]==0)*(with_mask[:,:,2]==0)\n", | |
"plt.imshow(with_mask, cmap='gray')\n", | |
"plt.axis('off')\n", | |
"plt.title('mask')\n", | |
"#plt.imsave(f\"./{fname.split('.')[1]}_mask.png\",mask, cmap='gray')\n", | |
"\n", | |
"plt.subplot(132)\n", | |
"img = np.array(plt.imread(f'{fname}')[:,:,:3])\n", | |
"plt.imshow(img)\n", | |
"plt.axis('off')\n", | |
"plt.title('img')\n", | |
"\n", | |
"plt.subplot(133)\n", | |
"img = np.array((1-mask.reshape(mask.shape[0], mask.shape[1], -1))*plt.imread(fname)[:,:,:3])\n", | |
"_=plt.imshow(img)\n", | |
"_=plt.axis('off')\n", | |
"_=plt.title('img * mask')\n", | |
"plt.show()\n", | |
"\n", | |
"print('Run inpainting')\n", | |
"if '.jpeg' in fname:\n", | |
" !PYTHONPATH=. TORCH_HOME=$(pwd) python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/data_for_prediction outdir=/content/output dataset.img_suffix=.jpeg > /dev/null\n", | |
"elif '.jpg' in fname:\n", | |
" !PYTHONPATH=. TORCH_HOME=$(pwd) python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/data_for_prediction outdir=/content/output dataset.img_suffix=.jpg > /dev/null\n", | |
"elif '.png' in fname:\n", | |
" !PYTHONPATH=. TORCH_HOME=$(pwd) python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/data_for_prediction outdir=/content/output dataset.img_suffix=.png > /dev/null\n", | |
"else:\n", | |
" print(f'Error: unknown suffix .{fname.split(\".\")[-1]} use [.png, .jpeg, .jpg]')\n", | |
"\n", | |
"plt.rcParams['figure.dpi'] = 200\n", | |
"plt.imshow(plt.imread(f\"/content/output/{fname.split('.')[1].split('/')[2]}_mask.png\"))\n", | |
"_=plt.axis('off')\n", | |
"_=plt.title('inpainting result')\n", | |
"plt.show()\n", | |
"fname = None" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "Ug9vfkBHqxzZ" | |
}, | |
"outputs": [], | |
"source": [ | |
"" | |
] | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"collapsed_sections": [], | |
"name": "LaMa-inpainting.ipynb", | |
"provenance": [] | |
}, | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment