-
-
Save ahoereth/f9b0305ce33b67fc219a87391ed1f41e to your computer and use it in GitHub Desktop.
TensorFlow: visualize convolutional features (conv1) in the CIFAR-10 model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def kernel_images(name, x, pad=1, max_images=-1, summarize=True):
    """Create image summaries of 2d convolution kernel weights.

    Arranges the ``chan_out`` kernels of a convolution weight tensor into a
    single grid image (per group of at most 4 input channels) so they can be
    inspected, e.g. in TensorBoard.

    Args:
        name: Name scope under which the ops (and the summary) are created.
        x: Kernel weights laid out as ``(rows, cols, chan_in, chan_out)``,
            the filter layout used by ``tf.nn.conv2d``. All four dimensions
            must be statically known.
        pad: Number of zero pixels added on each side of every kernel.
        max_images: Maximum number of images emitted by the summary; ``-1``
            emits all of them.
        summarize: If true, an image summary named ``'kernels'`` is created.

    Returns:
        A tensor of shape ``(batch, height, width, channels)`` with values in
        ``[0, 1]`` and ``channels`` in ``{1, 3, 4}``. ``batch`` is ``1``
        unless ``chan_in > 4``, in which case the input channels are split
        across several images.

    Raises:
        ValueError: If the shape of ``x`` is not fully statically known, or if
            ``chan_in`` cannot be halved down to at most 4 channels (an odd
            channel count larger than 4 is reached).
    """
    import math

    def factorization(n):
        """Return the most square ``(rows, cols)`` with ``rows * cols == n``.

        ``rows <= cols`` always holds; for ``n >= 1`` a result is guaranteed
        because ``i == 1`` divides every ``n``.
        """
        for i in range(int(math.sqrt(float(n))), 0, -1):
            if n % i == 0:
                return i, n // i

    with tf.name_scope(name):
        # Normalize values to [0, 1]. For a constant tensor (high == low) the
        # plain division would be 0 / 0 = NaN; divide by 1 instead so the
        # kernels render as all zeros.
        low, high = tf.reduce_min(x), tf.reduce_max(x)
        span = high - low
        span = tf.where(span > 0, span, tf.ones_like(span))
        x = (x - low) / span
        x = tf.pad(x, ((pad, pad), (pad, pad), (0, 0), (0, 0)))

        # Organize grid. as_list() yields plain ints (or None) in both TF1 and
        # TF2, whereas Dimension.value only exists in TF1.
        r, c, chan_in, chan_out = x.get_shape().as_list()
        if None in (r, c, chan_in, chan_out):
            raise ValueError(
                'kernel_images requires a fully defined kernel shape, got %s'
                % (x.get_shape(),))
        grid_r, grid_c = factorization(chan_out)
        x = tf.transpose(x, (3, 0, 1, 2))  # Move chan_out to front.
        # Stack grid_r consecutive kernels vertically into grid_c strips.
        x = tf.reshape(x, [grid_c, r * grid_r, c, chan_in])
        x = tf.transpose(x, (0, 2, 1, 3))  # Switch x and y.
        # Lay the grid_c strips side by side in a single image.
        x = tf.reshape(x, [1, c * grid_c, r * grid_r, chan_in])
        x = tf.transpose(x, (0, 2, 1, 3))  # batch, height, width, channels

        # Last dimension may only contain a maximum of 4 channels: repeatedly
        # halve the channel axis, moving the halves into the batch axis.
        batch, channels = 1, chan_in
        while channels > 4:
            if channels % 2:
                raise ValueError(
                    'kernel_images cannot split %d input channels into images '
                    'with at most 4 channels' % chan_in)
            a, b = tf.split(x, 2, axis=3)
            x = tf.concat([a, b], axis=0)
            batch, channels = batch * 2, channels // 2

        # Last dimension may only contain 1, 3 or 4 channels: pad 2 -> 3 with
        # a zero channel of the same dtype and batch size as x.
        if channels == 2:
            x = tf.concat([x, tf.zeros_like(x[:, :, :, :1])], axis=3)

        # Summarize requested amount of images.
        if summarize:
            if max_images == -1:
                max_images = batch
            # Pass max_outputs by keyword: in TF2 the third positional
            # parameter of tf.summary.image is `step`.
            tf.summary.image('kernels', x, max_outputs=max_images)
        return x
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment