Last active May 14, 2020 18:44
Visualize convolutional kernels with Tensorflow 2.0
```python
import numpy as np
import tensorflow as tf

# Layer name to inspect
layer_name = 'block3_conv1'

epochs = 100
step_size = 1.
filter_index = 0

# Create a connection between the input and the target layer
model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=True)
submodel = tf.keras.models.Model([model.inputs[0]], [model.get_layer(layer_name).output])

# Initiate random noise
input_img_data = np.random.random((1, 224, 224, 3))
input_img_data = (input_img_data - 0.5) * 20 + 128.

# Cast random noise from np.float64 to tf.float32 Variable
input_img_data = tf.Variable(tf.cast(input_img_data, tf.float32))

# Iterate gradient ascents
for _ in range(epochs):
    with tf.GradientTape() as tape:
        outputs = submodel(input_img_data)
        loss_value = tf.reduce_mean(outputs[:, :, :, filter_index])
    grads = tape.gradient(loss_value, input_img_data)
    normalized_grads = grads / (tf.sqrt(tf.reduce_mean(tf.square(grads))) + 1e-5)
    input_img_data.assign_add(normalized_grads * step_size)
```
Author
At the end of the loop, input_img_data is a 4D tensor holding the generated image. Convert it to a NumPy array with .numpy() and visualize it, with matplotlib for example.
Author
For matplotlib to display the image properly, either normalize the values to the 0-1 range or convert the image to an integer type. As the warning says, matplotlib clips float data to the 0-1 range, so an image whose values span 0-255 gets clipped, which is why it looks so poor.
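The first option (normalizing to the 0-1 range) could look roughly like the sketch below. It is only an illustration: `to_unit_range` is a hypothetical helper that is not part of the gist, min-max scaling is one possible normalization among several, and the random array stands in for `input_img_data.numpy()` so the snippet runs without TensorFlow.

```python
import numpy as np


def to_unit_range(batch):
    """Min-max scale a (1, H, W, C) array to floats in [0, 1] and drop the batch axis.

    Hypothetical helper, not part of the gist.
    """
    image = np.asarray(batch, dtype=np.float32)[0]
    lo, hi = image.min(), image.max()
    # Guard against a constant image, where hi - lo would be zero.
    return (image - lo) / max(hi - lo, 1e-8)


# Stand-in for input_img_data.numpy(): noise centred on 128, as in the gist.
rng = np.random.default_rng(0)
fake_result = (rng.random((1, 224, 224, 3)) - 0.5) * 20 + 128.0

image = to_unit_range(fake_result)
print(image.shape, float(image.min()), float(image.max()))
```

With the real tensor you would pass `input_img_data.numpy()` instead of `fake_result`. The resulting float array can then be shown with `plt.imshow(image)` without triggering the clipping warning.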
@konradsemsch I solved it by converting input_img_data this way:
input_img_data = input_img_data.numpy().astype(np.uint8)
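A slightly more defensive variant of the same idea is sketched below. It assumes you want values outside 0-255 clamped before the cast, since `astype(np.uint8)` on its own does not clamp out-of-range values. `to_uint8_image` is a hypothetical helper name, and the small array stands in for `input_img_data.numpy()`.

```python
import numpy as np


def to_uint8_image(batch):
    """Turn a (1, H, W, C) float array into an (H, W, C) uint8 image.

    Hypothetical helper, not part of the gist.
    """
    image = np.asarray(batch)[0]      # drop the batch axis
    image = np.clip(image, 0, 255)    # clamp before casting
    return image.astype(np.uint8)     # the cast truncates the fractional part


# Tiny stand-in for input_img_data.numpy(), with values outside 0-255.
sample = np.array([[[[-12.5, 130.2, 300.0]]]])
result = to_uint8_image(sample)
print(result.shape, result.dtype)
```

The uint8 array can be displayed with `plt.imshow(result)` or written to disk with `plt.imsave("filter.png", result)`, where the file name is just an example.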

Excuse my stupid question, but how do we actually see/save the image(s)?