Skip to content

Instantly share code, notes, and snippets.

@manzke
Created February 7, 2022 18:34
Show Gist options
  • Save manzke/d443be4dcf3f78a32b81dd0eb913b234 to your computer and use it in GitHub Desktop.
Save manzke/d443be4dcf3f78a32b81dd0eb913b234 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "_pRpIwnaOnb3"
},
"source": [
"# 🦙 **LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions**\n",
"\n",
"[[Project page](https://saic-mdal.github.io/lama-project/)] [[GitHub](https://github.com/saic-mdal/lama)] [[arXiv](https://arxiv.org/abs/2109.07161)] [[Supplementary](https://ashukha.com/projects/lama_21/lama_supmat_2021.pdf)] [[BibTeX](https://senya-ashukha.github.io/projects/lama_21/paper.txt)]\n",
"\n",
"<p align=\"center\" \"font-size:30px;\">\n",
"Our model generalizes surprisingly well to much higher resolutions (~2k❗️) than it saw during training (256x256), and achieves the excellent performance even in challenging scenarios, e.g. completion of periodic structures.\n",
"</p>\n",
"\n",
"# Try it yourself!👇\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "RwXRMaNHW4r5"
},
"outputs": [],
"source": [
"#@title Run this sell to set everything up\n",
"print('\\n> Cloning the repo')\n",
"!git clone https://github.com/saic-mdal/lama.git\n",
"\n",
"print('\\n> Install dependencies')\n",
"!pip install -r lama/requirements.txt --quiet\n",
"!pip install wget --quiet\n",
"\n",
"print('\\n> Changing the dir to:')\n",
"% cd /content/lama\n",
"\n",
"print('\\n> Download the model')\n",
"!curl -L $(yadisk-direct https://disk.yandex.ru/d/ouP6l8VJ0HpMZg) -o big-lama.zip\n",
"!unzip big-lama.zip\n",
"\n",
"print('>fixing opencv')\n",
"!pip uninstall opencv-python-headless -y --quiet\n",
"!pip install opencv-python-headless==4.1.2.30 --quiet\n",
"\n",
"\n",
"print('\\n> Init mask-drawing code')\n",
"import base64, os\n",
"from IPython.display import HTML, Image\n",
"from google.colab.output import eval_js\n",
"from base64 import b64decode\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import wget\n",
"from shutil import copyfile\n",
"import shutil\n",
"\n",
"\n",
"\n",
"canvas_html = \"\"\"\n",
"<style>\n",
".button {\n",
" background-color: #4CAF50;\n",
" border: none;\n",
" color: white;\n",
" padding: 15px 32px;\n",
" text-align: center;\n",
" text-decoration: none;\n",
" display: inline-block;\n",
" font-size: 16px;\n",
" margin: 4px 2px;\n",
" cursor: pointer;\n",
"}\n",
"</style>\n",
"<canvas1 width=%d height=%d>\n",
"</canvas1>\n",
"<canvas width=%d height=%d>\n",
"</canvas>\n",
"\n",
"<button class=\"button\">Finish</button>\n",
"<script>\n",
"var canvas = document.querySelector('canvas')\n",
"var ctx = canvas.getContext('2d')\n",
"\n",
"var canvas1 = document.querySelector('canvas1')\n",
"var ctx1 = canvas.getContext('2d')\n",
"\n",
"\n",
"ctx.strokeStyle = 'red';\n",
"\n",
"var img = new Image();\n",
"img.src = \"data:image/%s;charset=utf-8;base64,%s\";\n",
"console.log(img)\n",
"img.onload = function() {\n",
" ctx1.drawImage(img, 0, 0);\n",
"};\n",
"img.crossOrigin = 'Anonymous';\n",
"\n",
"ctx.clearRect(0, 0, canvas.width, canvas.height);\n",
"\n",
"ctx.lineWidth = %d\n",
"var button = document.querySelector('button')\n",
"var mouse = {x: 0, y: 0}\n",
"\n",
"canvas.addEventListener('mousemove', function(e) {\n",
" mouse.x = e.pageX - this.offsetLeft\n",
" mouse.y = e.pageY - this.offsetTop\n",
"})\n",
"canvas.onmousedown = ()=>{\n",
" ctx.beginPath()\n",
" ctx.moveTo(mouse.x, mouse.y)\n",
" canvas.addEventListener('mousemove', onPaint)\n",
"}\n",
"canvas.onmouseup = ()=>{\n",
" canvas.removeEventListener('mousemove', onPaint)\n",
"}\n",
"var onPaint = ()=>{\n",
" ctx.lineTo(mouse.x, mouse.y)\n",
" ctx.stroke()\n",
"}\n",
"\n",
"var data = new Promise(resolve=>{\n",
" button.onclick = ()=>{\n",
" resolve(canvas.toDataURL('image/png'))\n",
" }\n",
"})\n",
"</script>\n",
"\"\"\"\n",
"\n",
"def draw(imgm, filename='drawing.png', w=400, h=200, line_width=1):\n",
" display(HTML(canvas_html % (w, h, w,h, filename.split('.')[-1], imgm, line_width)))\n",
" data = eval_js(\"data\")\n",
" binary = b64decode(data.split(',')[1])\n",
" with open(filename, 'wb') as f:\n",
" f.write(binary)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "23WaUHiJeyBO"
},
"source": [
"<center>\n",
"<h1 style=\"font-size:10vw\"><b>Predefined photo</b>: uncomment any line\n",
"<br>\n",
"<b>Local file</b>: leave the <tt>fname = None</tt></h1>\n",
"</center>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFIDDD4IhPXd"
},
"outputs": [],
"source": [
"fname = None\n",
"#fname = '0ca8e4f3f4150937082add7dbdd149196dc05ec2.jpg'\n",
"# fname = 'https://ic.pics.livejournal.com/mostovoy/28566193/1224276/1224276_original.jpg' # <-in the example\n",
"# fname = 'https://raw.githubusercontent.com/senya-ashukha/senya-ashukha.github.io/master/images/1010286.jpeg'\n",
"# fname = 'https://raw.githubusercontent.com/senya-ashukha/senya-ashukha.github.io/master/images/1010287.jpeg'\n",
"# fname = \"https://raw.githubusercontent.com/senya-ashukha/senya-ashukha.github.io/master/images/alex.jpg\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "code",
"id": "-VZWySTMeGDM"
},
"outputs": [],
"source": [
"#@title Draw a Mask, Press Finish, Wait for Inpainting\n",
"\n",
"#if fname is None:\n",
"# from google.colab import files\n",
"# files = files.upload()\n",
"# fname = list(files.keys())[0]\n",
"#else:\n",
"# fname = wget.download(fname)\n",
"\n",
"#shutil.rmtree('./data_for_prediction', ignore_errors=True)\n",
"#! mkdir data_for_prediction\n",
"\n",
"#copyfile(fname, f'./data_for_prediction/{fname}')\n",
"#os.remove(fname)\n",
"fname = f'./data_for_prediction/{fname}'\n",
"\n",
"image64 = base64.b64encode(open(fname, 'rb').read())\n",
"image64 = image64.decode('utf-8')\n",
"\n",
"print(f'Will use {fname} for inpainting')\n",
"img = np.array(plt.imread(f'{fname}')[:,:,:3])\n",
"\n",
"#draw(image64, filename=f\"./{fname.split('.')[1]}_mask.png\", w=img.shape[1], h=img.shape[0], line_width=0.04*img.shape[1])\n",
"#@title Show a masked image and save a mask\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams[\"figure.figsize\"] = (15,5)\n",
"plt.rcParams['figure.dpi'] = 200\n",
"plt.subplot(131)\n",
"with_mask = np.array(plt.imread(f\"./{fname.split('.')[1]}_mask.png\")[:,:,:3])\n",
"plt.imsave(f\"./{fname.split('.')[1]}_mask_gray.png\", with_mask, cmap='gray')\n",
"with_mask = np.array(plt.imread(f\"./{fname.split('.')[1]}_mask_gray.png\")[:,:,:3])\n",
"mask = (with_mask[:,:,0]==1)*(with_mask[:,:,1]==0)*(with_mask[:,:,2]==0)\n",
"plt.imshow(with_mask, cmap='gray')\n",
"plt.axis('off')\n",
"plt.title('mask')\n",
"#plt.imsave(f\"./{fname.split('.')[1]}_mask.png\",mask, cmap='gray')\n",
"\n",
"plt.subplot(132)\n",
"img = np.array(plt.imread(f'{fname}')[:,:,:3])\n",
"plt.imshow(img)\n",
"plt.axis('off')\n",
"plt.title('img')\n",
"\n",
"plt.subplot(133)\n",
"img = np.array((1-mask.reshape(mask.shape[0], mask.shape[1], -1))*plt.imread(fname)[:,:,:3])\n",
"_=plt.imshow(img)\n",
"_=plt.axis('off')\n",
"_=plt.title('img * mask')\n",
"plt.show()\n",
"\n",
"print('Run inpainting')\n",
"if '.jpeg' in fname:\n",
" !PYTHONPATH=. TORCH_HOME=$(pwd) python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/data_for_prediction outdir=/content/output dataset.img_suffix=.jpeg > /dev/null\n",
"elif '.jpg' in fname:\n",
" !PYTHONPATH=. TORCH_HOME=$(pwd) python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/data_for_prediction outdir=/content/output dataset.img_suffix=.jpg > /dev/null\n",
"elif '.png' in fname:\n",
" !PYTHONPATH=. TORCH_HOME=$(pwd) python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/data_for_prediction outdir=/content/output dataset.img_suffix=.png > /dev/null\n",
"else:\n",
" print(f'Error: unknown suffix .{fname.split(\".\")[-1]} use [.png, .jpeg, .jpg]')\n",
"\n",
"plt.rcParams['figure.dpi'] = 200\n",
"plt.imshow(plt.imread(f\"/content/output/{fname.split('.')[1].split('/')[2]}_mask.png\"))\n",
"_=plt.axis('off')\n",
"_=plt.title('inpainting result')\n",
"plt.show()\n",
"fname = None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ug9vfkBHqxzZ"
},
"outputs": [],
"source": [
""
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "LaMa-inpainting.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment