Skip to content

Instantly share code, notes, and snippets.

@PaulCreusy
Created November 7, 2024 16:01
Show Gist options
  • Save PaulCreusy/b4bbc6d04550950d9640d2edba9be11e to your computer and use it in GitHub Desktop.
Save PaulCreusy/b4bbc6d04550950d9640d2edba9be11e to your computer and use it in GitHub Desktop.
Neural Interpretation Diagram plot for Python and TensorFlow
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
def _dense_params(layer):
    """Return ``(kernel, bias)`` for a Dense-like layer; ``bias`` is None if absent.

    Raises:
        ValueError: if the layer's weights are not a single 2-D kernel plus an
            optional 1-D bias matching the kernel's output width, i.e. the
            layer cannot be drawn neuron-by-neuron (e.g. convolutional or
            normalization layers).
    """
    params = layer.get_weights()
    kernel = np.asarray(params[0])
    bias = np.asarray(params[1]) if len(params) > 1 else None
    if (
        len(params) > 2
        or kernel.ndim != 2
        or (bias is not None and bias.shape != (kernel.shape[1],))
    ):
        raise ValueError(
            f"Layer {layer.name!r} is not a Dense-like layer (weight shapes "
            f"{[np.shape(p) for p in params]}); only a 2-D kernel with an "
            "optional 1-D bias can be plotted at neuron level."
        )
    return kernel, bias


def _centered_column(count, spacing):
    """Return ``count`` y-coordinates centred on 0 and ``spacing`` apart."""
    return np.linspace(-(count - 1) / 2, (count - 1) / 2, count) * spacing


def _build_nid_graph(model):
    """Build the NID graph and node positions from the weighted layers of *model*.

    Weighted layers are taken in ``model.layers`` order and connected as a
    single feed-forward chain; layers without weights (Flatten, Dropout,
    InputLayer, ...) are skipped.

    Node ids:
        ``("input", i)`` for input neurons, ``(layer_index, j)`` for neuron
        ``j`` of ``model.layers[layer_index]``, and ``(layer_index, j, "b")``
        for that neuron's bias. Every node has a ``label`` attribute.

    Returns:
        ``(G, pos)``: a ``networkx.DiGraph`` whose edges carry a ``weight``
        attribute (connection weights and bias values), and a dict mapping
        every node to its ``(x, y)`` plot position.

    Raises:
        ValueError: if the model has no weighted layers, contains a layer that
            is not Dense-like, or if consecutive weighted layers do not chain
            (kernel fan-in differs from the previous layer's width).
    """
    # Layout constants.
    neuron_vertical_spacing = 1.0
    layer_horizontal_spacing = 3.0
    # Bias nodes sit in their neuron's column (no horizontal shift), half a
    # neuron spacing below it.
    bias_horizontal_offset = 0.0
    bias_vertical_offset = 0.5 * neuron_vertical_spacing

    layers = [
        (layer_index, layer, *_dense_params(layer))
        for layer_index, layer in enumerate(model.layers)
        if getattr(layer, "weights", None)
    ]
    if not layers:
        raise ValueError("The model has no layers with weights to plot.")

    G = nx.DiGraph()
    pos = {}
    x = 0.0

    # Size the input column from the first kernel's fan-in: that is exactly
    # what the first layer's edges connect to. (`input_shape` is missing in
    # Keras 3, and is the pre-Flatten shape for image inputs.)
    num_inputs = layers[0][2].shape[0]
    source_nodes = [("input", i) for i in range(num_inputs)]
    for i, y in enumerate(_centered_column(num_inputs, neuron_vertical_spacing)):
        G.add_node(source_nodes[i], label=f"Input_{i}")
        pos[source_nodes[i]] = (x, y)

    for layer_index, layer, kernel, bias in layers:
        fan_in, num_neurons = kernel.shape
        if fan_in != len(source_nodes):
            raise ValueError(
                f"Layer {layer.name!r} expects {fan_in} inputs but the previous "
                f"plotted layer has {len(source_nodes)} neurons; only simple "
                "feed-forward stacks of Dense-like layers are supported."
            )
        x += layer_horizontal_spacing

        target_nodes = [(layer_index, j) for j in range(num_neurons)]
        ys = _centered_column(num_neurons, neuron_vertical_spacing)
        for j, (neuron_id, y) in enumerate(zip(target_nodes, ys)):
            G.add_node(neuron_id, label=f"{layer.name}_{j}")
            pos[neuron_id] = (x, y)
            if bias is not None:  # Dense(use_bias=False) has no bias vector
                bias_id = (layer_index, j, "b")
                G.add_node(bias_id, label=f"B_{layer_index}_{j}")
                pos[bias_id] = (x - bias_horizontal_offset, y - bias_vertical_offset)
                G.add_edge(bias_id, neuron_id, weight=float(bias[j]))

        # Connect to the previous *weighted* layer (not `layer_index - 1`,
        # which may be a skipped Dropout/Flatten layer with no nodes).
        for i, source in enumerate(source_nodes):
            for j, target in enumerate(target_nodes):
                G.add_edge(source, target, weight=float(kernel[i, j]))
        source_nodes = target_nodes

    return G, pos


def _draw_nid_graph(G, pos):
    """Render a graph built by ``_build_nid_graph`` with matplotlib and show it."""
    plt.figure(figsize=(12, 8))
    nx.draw_networkx_nodes(G, pos, node_size=300, node_color="skyblue")
    # Pass the stored `label` attributes explicitly; without `labels=`
    # networkx draws the raw tuple node ids instead.
    nx.draw_networkx_labels(
        G,
        pos,
        labels=nx.get_node_attributes(G, "label"),
        font_size=8,
        font_color="black",
    )

    # Materialize once so the edge list, colours and labels share one order.
    edge_data = list(G.edges(data="weight"))
    edgelist = [(u, v) for u, v, _ in edge_data]
    # Colour encodes magnitude only; the sign is visible in the edge labels.
    edge_colors = [abs(w) for _, _, w in edge_data]
    nx.draw_networkx_edges(
        G, pos, edgelist=edgelist, edge_color=edge_colors,
        edge_cmap=plt.cm.Blues, width=1.5,
    )
    edge_labels = {(u, v): f"{w:.2f}" for u, v, w in edge_data}
    nx.draw_networkx_edge_labels(
        G, pos, edge_labels=edge_labels, label_pos=0.25, font_size=6)

    plt.title(
        "Neuron-Level Neural Interpretation Diagram (NID) of the Neural Network with Biases")
    plt.axis("off")
    plt.show()


def plot_neuron_level_nid(model):
    """
    Plot a neuron-level Neural Interpretation Diagram (NID) for a TensorFlow model,
    including biases.

    Every neuron of each Dense-like layer is drawn as a node. Every
    input->neuron weight and bias->neuron value is drawn as a labelled edge,
    coloured by absolute value. Layers without weights (Flatten, Dropout,
    InputLayer, ...) are skipped. The remaining layers are connected in
    ``model.layers`` order.

    Parameters:
    model (tf.keras.Model): A TensorFlow Keras model instance made of a simple
        feed-forward stack of Dense-like layers.

    Raises:
    ValueError: if the model has no weighted layers, contains a layer whose
        weights are not a 2-D kernel plus optional bias, or if consecutive
        weighted layers do not chain.
    """
    G, pos = _build_nid_graph(model)
    _draw_nid_graph(G, pos)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment