Created November 7, 2024 16:01
-
-
Save PaulCreusy/b4bbc6d04550950d9640d2edba9be11e to your computer and use it in GitHub Desktop.
Neural Interpretation Diagram plot for Python and Tensorflow
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import networkx as nx | |
import numpy as np | |
def plot_neuron_level_nid(model): | |
""" | |
Plot a neuron-level Neural Interpretation Diagram (NID) for a TensorFlow model, | |
including biases. | |
Parameters: | |
model (tf.keras.Model): A TensorFlow Keras model instance. | |
""" | |
# Initialize directed graph | |
G = nx.DiGraph() | |
pos = {} # Dictionary to store positions of each neuron node for visualization | |
layer_positions = [] # Track horizontal position of each layer for organized plotting | |
# Vertical and horizontal spacing for neurons and biases | |
neuron_vertical_spacing = 1.0 | |
layer_horizontal_spacing = 3.0 | |
# Offset bias nodes slightly to the left of their neurons | |
bias_horizontal_offset = 0 | |
bias_vertical_offset = 0.5 | |
# Define x position for each layer | |
x = 0 | |
# Process the input layer | |
input_shape = model.layers[0].input_shape[1] | |
y_positions = np.linspace(-(input_shape - 1) / 2, | |
(input_shape - 1) / 2, input_shape) | |
# Add input neurons as nodes in the graph | |
for neuron_index, y in enumerate(y_positions): | |
node_id = ("input", neuron_index) | |
G.add_node(node_id, label=f"Input_{neuron_index}") | |
pos[node_id] = (x, y) | |
# Update horizontal position for the next layer | |
layer_positions.append(x) | |
x += layer_horizontal_spacing | |
# Process hidden and output layers | |
for layer_index, layer in enumerate(model.layers): | |
if not hasattr(layer, 'weights') or not layer.weights: | |
# Skip layers without weights (e.g., Flatten or Dropout layers) | |
continue | |
weights, biases = layer.get_weights() # Get weights and biases of the layer | |
num_neurons = weights.shape[1] | |
# Position neurons vertically within the layer | |
y_positions = np.linspace(-(num_neurons - 1) / 2, | |
(num_neurons - 1) / 2, num_neurons) | |
# Add each neuron and its bias as nodes in the graph | |
for neuron_index, y in enumerate(y_positions): | |
# Unique ID for each neuron in the layer | |
neuron_id = (layer_index, neuron_index) | |
G.add_node(neuron_id, label=f"{layer.name}_{neuron_index}") | |
pos[neuron_id] = (x, y) # Assign position for visualization | |
# Add a bias node | |
bias_id = (layer_index, neuron_index, 'b') | |
G.add_node(bias_id, label=f"B_{layer_index}_{neuron_index}") | |
pos[bias_id] = (x - bias_horizontal_offset, | |
y - bias_vertical_offset) | |
# Connect the bias to the neuron with an edge showing the bias value | |
bias_value = biases[neuron_index] | |
G.add_edge(bias_id, neuron_id, weight=bias_value) | |
# Update horizontal position for the next layer | |
layer_positions.append(x) | |
x += layer_horizontal_spacing | |
# Add edges with weights between neurons in consecutive layers | |
for layer_index, layer in enumerate(model.layers): | |
if not hasattr(layer, 'weights') or not layer.weights: | |
continue | |
weights = layer.get_weights()[0] | |
# Determine source layer: input layer or previous hidden layer | |
if layer_index == 0: | |
source_neurons = [("input", i) for i in range(weights.shape[0])] | |
else: | |
source_neurons = [(layer_index - 1, i) | |
for i in range(weights.shape[0])] | |
# Target neurons are in the current layer | |
target_neurons = [(layer_index, j) for j in range(weights.shape[1])] | |
# Add edges with weights between source neurons and target neurons | |
for i, source_node in enumerate(source_neurons): | |
for j, target_node in enumerate(target_neurons): | |
weight = weights[i, j] | |
G.add_edge(source_node, target_node, weight=weight) | |
# Draw the network | |
plt.figure(figsize=(12, 8)) | |
nx.draw_networkx_nodes(G, pos, node_size=300, node_color='skyblue') | |
nx.draw_networkx_labels(G, pos, font_size=8, font_color="black") | |
# Draw edges with weights | |
edges = G.edges(data=True) | |
weights = [data['weight'] for _, _, data in edges] | |
edge_colors = [np.abs(weight) for weight in weights] | |
nx.draw_networkx_edges( | |
G, pos, edgelist=edges, edge_color=edge_colors, edge_cmap=plt.cm.Blues, width=1.5) | |
# Add edge labels to show weights and biases | |
edge_labels = {(u, v): f'{d["weight"]:.2f}' for u, v, d in edges} | |
nx.draw_networkx_edge_labels( | |
G, pos, edge_labels=edge_labels, label_pos=0.25, font_size=6) | |
# Display the plot | |
plt.title( | |
"Neuron-Level Neural Interpretation Diagram (NID) of the Neural Network with Biases") | |
plt.axis("off") | |
plt.show() |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.