TensorFlow Convolution Gradients
""" | |
Demostrating how to compute the gradients for convolution with: | |
tf.nn.conv2d | |
tf.nn.conv2d_backprop_input | |
tf.nn.conv2d_backprop_filter | |
tf.nn.conv2d_transpose | |
This is the scripts for this answer: https://stackoverflow.com/a/44350789/1255535 | |
""" | |
import tensorflow as tf
import numpy as np
import scipy.signal


def tf_rot180(w):
    """
    Rotate the filter's spatial dimensions (height, width) by 180 degrees.
    """
    return tf.reverse(w, axis=[0, 1])


def tf_pad_to_full_conv2d(x, w_size):
    """
    Pad x such that a 'VALID' convolution in TensorFlow gives the same
    result as a 'FULL' convolution. See
    http://deeplearning.net/software/theano/library/tensor/nnet/conv.html#theano.tensor.nnet.conv2d
    for a description of 'FULL' convolution.
    """
    return tf.pad(x, [[0, 0],
                      [w_size - 1, w_size - 1],
                      [w_size - 1, w_size - 1],
                      [0, 0]])
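
# Size check: a 'VALID' conv of an n x n input with a k x k filter yields
# (n - k + 1) x (n - k + 1). Padding each spatial side by k - 1 gives
# n + 2(k - 1), so the 'VALID' output becomes (n + k - 1) x (n + k - 1),
# which is exactly the 'FULL' output size.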

def tf_NHWC_to_HWIO(out):
    """
    Converts [batch, in_height, in_width, in_channels]
    to [filter_height, filter_width, in_channels, out_channels].
    """
    return tf.transpose(out, perm=[1, 2, 0, 3])
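
# With batch = in_channels = out_channels = 1, swapping the batch and channel
# axes (perm=[1, 2, 0, 3]) lets d_out, an NHWC activation, be reused as an
# HWIO filter, and lets the resulting NHWC conv output be read back in w's
# HWIO layout when computing d_w below.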

# sizes; strides are fixed to 1, and in_channels = out_channels = 1 for now
x_size = 4
w_size = 3  # use an odd number here
x_shape = (1, x_size, x_size, 1)
w_shape = (w_size, w_size, 1, 1)
out_shape = (1, x_size - w_size + 1, x_size - w_size + 1, 1)
strides = (1, 1, 1, 1)

# numpy values
x_np = np.random.randint(10, size=x_shape)
w_np = np.random.randint(10, size=w_shape)
out_scale_np = np.random.randint(10, size=out_shape)

# tf forward
x = tf.constant(x_np, dtype=tf.float32)
w = tf.constant(w_np, dtype=tf.float32)
out = tf.nn.conv2d(input=x, filter=w, strides=strides, padding='VALID')
out_scale = tf.constant(out_scale_np, dtype=tf.float32)
f = tf.reduce_sum(tf.multiply(out, out_scale))
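
# Since f = sum(out * out_scale), the gradient of f w.r.t. out is out_scale
# itself, so d_out below is numerically equal to out_scale_np.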

# tf backward
d_out = tf.gradients(f, out)[0]

# 4 different ways to compute d_x
d_x = tf.gradients(f, x)[0]
d_x_manual = tf.nn.conv2d(input=tf_pad_to_full_conv2d(d_out, w_size),
                          filter=tf_rot180(w),
                          strides=strides,
                          padding='VALID')
d_x_backprop_input = tf.nn.conv2d_backprop_input(input_sizes=x_shape,
                                                 filter=w,
                                                 out_backprop=d_out,
                                                 strides=strides,
                                                 padding='VALID')
d_x_transpose = tf.nn.conv2d_transpose(value=d_out,
                                       filter=w,
                                       output_shape=x_shape,
                                       strides=strides,
                                       padding='VALID')

# 3 different ways to compute d_w
d_w = tf.gradients(f, w)[0]
d_w_manual = tf_NHWC_to_HWIO(tf.nn.conv2d(input=x,
                                          filter=tf_NHWC_to_HWIO(d_out),
                                          strides=strides,
                                          padding='VALID'))
d_w_backprop_filter = tf.nn.conv2d_backprop_filter(input=x,
                                                   filter_sizes=w_shape,
                                                   out_backprop=d_out,
                                                   strides=strides,
                                                   padding='VALID')

# run
with tf.Session() as sess:
    np.testing.assert_allclose(sess.run(d_x), sess.run(d_x_manual))
    np.testing.assert_allclose(sess.run(d_x), sess.run(d_x_backprop_input))
    np.testing.assert_allclose(sess.run(d_x), sess.run(d_x_transpose))
    np.testing.assert_allclose(sess.run(d_w), sess.run(d_w_manual))
    np.testing.assert_allclose(sess.run(d_w), sess.run(d_w_backprop_filter))
""" | |
Get the same results using numpy / scipy | |
""" | |
def rot180(x): | |
return np.flipud(np.fliplr(x)) | |
def conv2d(x, w, mode='full', boundary='fill', fillvalue=0): | |
""" | |
2d convolution without rot180, scipy's conv2d rotates | |
the filter by 180 degress. | |
""" | |
return scipy.signal.convolve2d(x, rot180(w), | |
mode=mode, | |
boundary=boundary, | |
fillvalue=fillvalue) | |
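
# With the filter pre-rotated, this wrapper computes a cross-correlation,
# which is also what tf.nn.conv2d computes.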

# convert from 4d to 2d
x_np = x_np.squeeze()
w_np = w_np.squeeze()
d_out_np = out_scale_np.squeeze()

# compute the gradients manually
out_np = conv2d(x_np, w_np, mode='valid')
d_x_np = conv2d(d_out_np, rot180(w_np), mode='full')
d_w_np = conv2d(x_np, d_out_np, mode='valid')
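
# These mirror d_x_manual and d_w_manual above: d_x is the 'FULL'
# cross-correlation of d_out with the 180-degree-rotated filter, and d_w is
# the 'VALID' cross-correlation of x with d_out.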

# run
with tf.Session() as sess:
    np.testing.assert_allclose(sess.run(d_x).squeeze(), d_x_np)
    np.testing.assert_allclose(sess.run(d_w).squeeze(), d_w_np)
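
# Optional sanity check (a minimal sketch; loss_np and finite_diff are helper
# names assumed here): f is linear in both x and w, so central finite
# differences recover d_x_np and d_w_np to high precision without TensorFlow.
def loss_np(x_2d, w_2d):
    """f = sum(out * out_scale), computed with the numpy/scipy conv2d above."""
    return np.sum(conv2d(x_2d, w_2d, mode='valid') * d_out_np)


def finite_diff(fn, a, eps=1e-3):
    """Central finite-difference gradient of scalar fn w.r.t. array a."""
    grad = np.zeros(a.shape, dtype=np.float64)
    for idx in np.ndindex(a.shape):
        a_plus = a.astype(np.float64)
        a_minus = a.astype(np.float64)
        a_plus[idx] += eps
        a_minus[idx] -= eps
        grad[idx] = (fn(a_plus) - fn(a_minus)) / (2 * eps)
    return grad


np.testing.assert_allclose(finite_diff(lambda a: loss_np(a, w_np), x_np),
                           d_x_np, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(finite_diff(lambda a: loss_np(x_np, a), w_np),
                           d_w_np, rtol=1e-4, atol=1e-4)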