Skip to content

Instantly share code, notes, and snippets.

@dpiponi
Created July 21, 2018 18:02
Show Gist options
  • Save dpiponi/8a3ecade78e7df92191e7f3f50461459 to your computer and use it in GitHub Desktop.
Save dpiponi/8a3ecade78e7df92191e7f3f50461459 to your computer and use it in GitHub Desktop.
Shallow image prior
# Trying to get a handle on "Deep Image Prior"
# at https://dmitryulyanov.github.io/deep_image_prior
# This is a toy version with a single purely linear
# convolution layer
# The goal is to start with an image with high-res detail,
# corrupt a few pixels, and then
# repair the corrupt bits using an a priori model
# that simply says "we can make the image from a lower resolution
# one using a transposed convolution".
#
# Because the optimisation needs to reconstruct high frequency
# detail from a low frequency image, it will try to find
# repeating patterns in the high frequencies.
# In this case, it figures out that the original image has
# a certain stipple pattern and uses that to generate the
# high res image.
# It can't represent high-resolution information that doesn't fit
# the pattern, so the corrupted pixels are attenuated rather than reproduced.
#
# This is a super-simple model, so you can't expect it to do
# a perfect job. You'll still see some image corruption,
# but hopefully it's mitigated a bit.
from __future__ import absolute_import, division, print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.contrib.eager as tfe
# Side length, in pixels, of the square target image.
N = 64
# Eager mode is switched on up front, before any TF ops are created below.
tf.enable_eager_execution()
# Generate a stipple pattern laid over a gradient:
#   - (i % 3) * (j % 3) repeats every 3 pixels in both directions and takes
#     the values 0, 1, 2 or 4; scaled by 0.125 it lies in [0, 0.5].
#   - 0.5 * i / N is a left-to-right ramp from 0 to just under 0.5.
# i has shape (1, N) (column index) and j has shape (N, 1) (row index), so
# broadcasting yields an (N, N) float64 image with values in [0, 1).
# (The previous np.zeros(...) initialisation was a dead store and is gone.)
i = np.arange(N)[np.newaxis, :]
j = np.arange(N)[:, np.newaxis]
image = 0.125*(i % 3)*(j % 3) + 0.5*i/float(N)
# Corrupt some pixels by setting them to white (1). Positions are drawn from
# the interior only. loss() and the plots both crop a 2-pixel border, so a
# corruption landing there would be neither repaired nor displayed.
# Duplicate positions are possible, so up to M pixels are corrupted.
M = 10
image[np.random.randint(2, N - 2, size=M),
      np.random.randint(2, N - 2, size=M)] = 1
# Add batch and channel axes, giving shape (1, N, N, 1).
image = np.expand_dims(image, 0)
image = np.expand_dims(image, 3)
# Start with a random half-size image and a random 5x5 kernel. Both are
# float64 (np.random.randn's dtype), matching the dtype of `image`.
image0 = tfe.Variable(np.random.randn(1, N//2, N//2, 1))
kernel = tfe.Variable(np.random.randn(5, 5, 1, 1))
# Create full size image using transposed convolution
def generated_image():
return tf.nn.conv2d_transpose(image0, kernel, [1, N, N, 1],
strides=[1, 2, 2, 1], padding='SAME')
# I don't want to get sidetracked by edge effects so I don't
# use the edge for training and don't render it either
def loss():
return tf.reduce_sum((generated_image()-image)[0, 2:-2, 2:-2, 0]**2)
# Plain gradient descent, applied jointly to the latent image and the kernel.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
# implicit_gradients(loss) returns a function that evaluates loss() and
# returns (gradient, variable) pairs for the trainable variables loss()
# touches (here image0 and kernel).
grad = tfe.implicit_gradients(loss)
for _ in range(10000):
    optimizer.apply_gradients(grad())
print(kernel)
result = np.array(generated_image())
# Convert the learned kernel to a NumPy array explicitly, the same way
# `result` is converted, rather than handing imshow a TF variable slice.
kernel_img = np.array(kernel[:, :, 0, 0])
# Plot the original image with its corrupt pixels (2-pixel border cropped,
# matching the region loss() trains on).
plt.subplot(131)
plt.imshow(image[0, 2:-2, 2:-2, 0], vmin=0, vmax=1, cmap='gray')
plt.title('corrupted')
# "Repaired" image, cropped the same way and shown on the same 0..1 scale.
plt.subplot(132)
plt.imshow(result[0, 2:-2, 2:-2, 0], vmin=0, vmax=1, cmap='gray')
plt.title('repaired')
# The kernel used for repair. It should contain useful info
# about the original pattern.
plt.subplot(133)
plt.imshow(kernel_img)
plt.title('kernel')
# Maximise the window before inspecting the result. At small window sizes
# imshow() resamples the images, and the resulting aliasing can obscure
# the stipple pattern.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment