@Voyz
Last active April 11, 2018 21:29
visualise features for tensorflow
import tensorflow as tf

# feature - the tensor to visualise, in shape [num_examples, dimX, dimY, channels], e.g. [20, 7, 7, 16]. If dims is provided, the feature tensor can be flattened.
# grid - the dimensions of the grid on which the features are displayed. Expects a tuple (x, y) where x * y must equal the number of channels of the feature, e.g. grid=[4, 4] for 16 channels.
# dims (optional) - dimensions of each feature, e.g. dims=[7, 7] for feature [20, 7, 7, 16]. Can be omitted for a non-flattened (4-D) tensor.
# max_outputs - maximum number of examples to draw
def visualise_feature(feature, grid, dims=None, max_outputs=10):
    with tf.name_scope('Visualize_filters'):
        original_shape = feature.get_shape().as_list()
        original_len = len(original_shape)
        if original_len == 4 and dims is None:
            dims = [original_shape[1], original_shape[2]]
        # A flattened tensor carries no spatial shape, so dims must be given.
        assert dims is not None, 'dims is required when feature is flattened'
        dimA = dims[0]
        dimB = dims[1]
        if original_len < 4:
            channels = original_shape[2] if original_len == 3 else 1
            feature = tf.reshape(feature, (-1, dimA, dimB, channels))
        cx = grid[0]
        cy = grid[1]
        # Split the channel axis into a grid of cy rows by cx columns.
        V = tf.reshape(feature, (-1, dimA, dimB, cy, cx, 1))
        # Reorder to [example, grid row, dimA, grid column, dimB, 1].
        V = tf.transpose(V, (0, 3, 1, 4, 2, 5))
        # Merge each grid row with dimA and each grid column with dimB,
        # giving one [cy * dimA, cx * dimB] greyscale image per example.
        V = tf.reshape(V, (-1, cy * dimA, cx * dimB, 1))
        # TensorFlow 1.x-style summary call.
        tf.summary.image("Visualize_kernels", V, max_outputs=max_outputs)
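
The reshape and transpose in visualise_feature decide which channel ends up in which cell of the grid, and that is hard to read off the axis permutation alone. The sketch below spells out the same index mapping for a single example using plain nested lists, with no TensorFlow involved. It is an illustration, not part of the gist: the helper name tile_channels is hypothetical, and it assumes dimA is the image height, dimB the width, and that channels fill the grid row by row.

```python
def tile_channels(feature, grid):
    """Tile one example's channels onto a grid, using plain nested lists.

    feature: nested list indexed as feature[i][j][channel],
             with i in range(dim_a) and j in range(dim_b)
    grid:    (cx, cy), the number of grid columns and rows;
             cx * cy must equal the number of channels
    Returns a nested list with cy * dim_a rows and cx * dim_b columns.
    """
    cx, cy = grid
    dim_a = len(feature)
    dim_b = len(feature[0])
    channels = len(feature[0][0])
    assert cx * cy == channels, "grid must have exactly one cell per channel"

    image = [[0] * (cx * dim_b) for _ in range(cy * dim_a)]
    for row in range(cy):
        for col in range(cx):
            # Channels are laid out row by row, matching a row-major
            # split of the channel axis into (cy, cx).
            channel = row * cx + col
            for i in range(dim_a):
                for j in range(dim_b):
                    image[row * dim_a + i][col * dim_b + j] = feature[i][j][channel]
    return image


# A 2x3 feature with 4 channels; each value encodes (channel, i, j)
# as channel * 100 + i * 10 + j so the tiling is easy to inspect.
feature = [[[c * 100 + i * 10 + j for c in range(4)]
            for j in range(3)]
           for i in range(2)]
image = tile_channels(feature, (2, 2))
for line in image:
    print(line)
```

With a (2, 2) grid, channel 0 fills the top-left 2x3 block, channel 1 the top-right, channel 2 the bottom-left, and channel 3 the bottom-right, so the resulting image has 4 rows and 6 columns.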