Created
July 21, 2018 18:02
-
-
Save dpiponi/8a3ecade78e7df92191e7f3f50461459 to your computer and use it in GitHub Desktop.
Shallow image prior
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Trying to get a handle on "Deep Image Prior"
# at https://dmitryulyanov.github.io/deep_image_prior
# This is a toy version with a single purely linear
# convolution layer.
# The goal is to start with an image containing high-res detail,
# corrupt a few pixels, and then
# repair the corrupted pixels using an a priori model
# that simply says "we can make the image from a lower-resolution
# one using a transposed convolution".
#
# Because the optimisation needs to reconstruct high-frequency
# detail from a low-frequency image, it will try to find
# repeating patterns in the high frequencies.
# In this case, it figures out that the original image has
# a certain stipple pattern and uses that to generate the
# high-res image.
# It can't represent high-resolution information that doesn't fit
# the pattern, so the corrupted pixels get reduced.
#
# This is a super-simple model, so you can't expect it to do
# a good job. You're still going to have some image
# corruption, but hopefully it's mitigated a bit.
from __future__ import absolute_import, division, print_function | |
import tensorflow as tf | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import tensorflow.contrib.eager as tfe | |
N = 64 | |
tf.enable_eager_execution() | |
# Generate a stipple pattern laid over a gradient | |
image = np.zeros((N, N), dtype=np.float32) | |
i = np.expand_dims(range(N), 0) | |
j = np.expand_dims(range(N), 1) | |
image = 0.125*(i%3)*(j%3)+0.5*i/float(N) | |
# Corrupt some pixels | |
M = 10 | |
image[np.random.randint(N, size=M), np.random.randint(N, size=M)] = 1 | |
image = np.expand_dims(image, 0) | |
image = np.expand_dims(image, 3) | |
# Start with random half-size image and random kernel | |
image0 = tfe.Variable(np.random.randn(1, N//2, N//2, 1)) | |
kernel = tfe.Variable(np.random.randn(5, 5, 1, 1)) | |
# Upsample the trainable half-resolution image to full resolution with a
# learned transposed convolution (stride 2 in both spatial directions).
def generated_image():
    """Return the (1, N, N, 1) image synthesised from image0 and kernel."""
    output_shape = [1, N, N, 1]
    return tf.nn.conv2d_transpose(image0, kernel, output_shape,
                                  strides=[1, 2, 2, 1], padding='SAME')
# Edge pixels are excluded from the objective so boundary effects of the
# 'SAME'-padded transposed convolution don't dominate training (and they
# are cropped out of the rendering below for the same reason).
def loss():
    """Sum of squared differences to the target, ignoring a 2-pixel border."""
    diff = generated_image() - image
    return tf.reduce_sum(diff[0, 2:-2, 2:-2, 0] ** 2)
# Plain gradient descent; implicit_gradients differentiates loss() with
# respect to every tfe.Variable it touches (image0 and kernel).
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
grad_fn = tfe.implicit_gradients(loss)
for _ in range(10000):
    optimizer.apply_gradients(grad_fn())
print(kernel)

# Materialise the final reconstruction as a NumPy array for plotting.
result = np.array(generated_image())
# --- Visualisation ---------------------------------------------------------
# Left: the corrupted target. Middle: the reconstruction. Right: the learned
# kernel, which should encode useful info about the original stipple pattern.
# The same 2-pixel border that was excluded from the loss is cropped here.
crop = (0, slice(2, -2), slice(2, -2), 0)

plt.subplot(131)
plt.imshow(image[crop], vmin=0, vmax=1, cmap='gray')
plt.title('corrupted')

plt.subplot(132)
plt.imshow(result[crop], vmin=0, vmax=1, cmap='gray')
plt.title('repaired')

plt.subplot(133)
plt.imshow(kernel[:, :, 0, 0])
plt.title('kernel')

# I recommend expanding window to max size otherwise imshow() causes aliasing
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment