
@craffel
Created January 10, 2015 04:59
Draw a neural network diagram with matplotlib!
import matplotlib.pyplot as plt


def draw_neural_net(ax, left, right, bottom, top, layer_sizes):
    '''
    Draw a neural network cartoon using matplotlib.

    :usage:
        >>> fig = plt.figure(figsize=(12, 12))
        >>> draw_neural_net(fig.gca(), .1, .9, .1, .9, [4, 7, 2])

    :parameters:
        - ax : matplotlib.axes.Axes
            The axes on which to plot the cartoon (e.g. from plt.gca())
        - left : float
            The center of the leftmost node(s) will be placed here
        - right : float
            The center of the rightmost node(s) will be placed here
        - bottom : float
            Lower edge of the drawing area; the bottommost node centers sit half a node spacing above it
        - top : float
            Upper edge of the drawing area; the topmost node centers sit half a node spacing below it
        - layer_sizes : list of int
            List of layer sizes, including input and output dimensionality
    '''
    n_layers = len(layer_sizes)
    v_spacing = (top - bottom)/float(max(layer_sizes))
    h_spacing = (right - left)/float(n_layers - 1)
    # Nodes: layers run left to right, nodes within a layer run top to bottom
    for n, layer_size in enumerate(layer_sizes):
        layer_top = v_spacing*(layer_size - 1)/2. + (top + bottom)/2.
        for m in range(layer_size):
            circle = plt.Circle((n*h_spacing + left, layer_top - m*v_spacing), v_spacing/4.,
                                color='w', ec='k', zorder=4)
            ax.add_artist(circle)
    # Edges: connect every node in layer n to every node in layer n + 1
    for n, (layer_size_a, layer_size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        layer_top_a = v_spacing*(layer_size_a - 1)/2. + (top + bottom)/2.
        layer_top_b = v_spacing*(layer_size_b - 1)/2. + (top + bottom)/2.
        for m in range(layer_size_a):
            for o in range(layer_size_b):
                line = plt.Line2D([n*h_spacing + left, (n + 1)*h_spacing + left],
                                  [layer_top_a - m*v_spacing, layer_top_b - o*v_spacing], c='k')
                ax.add_artist(line)
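The node layout in `draw_neural_net` comes down to two spacings: `v_spacing = (top - bottom) / max(layer_sizes)` between nodes in a layer, and `h_spacing = (right - left) / (len(layer_sizes) - 1)` between layers, with every layer centered vertically on `(top + bottom) / 2`. The sketch below reproduces just that arithmetic without drawing anything, which can help when placing extra labels or checking the layout. `node_positions` is a hypothetical helper, not part of the gist.

```python
def node_positions(left, right, bottom, top, layer_sizes):
    """Hypothetical helper: return one list of (x, y) node centers per layer.

    Mirrors the spacing arithmetic of draw_neural_net: nodes are v_spacing
    apart vertically, layers are h_spacing apart horizontally, and each
    layer is centered on (top + bottom) / 2.
    """
    v_spacing = (top - bottom) / float(max(layer_sizes))
    h_spacing = (right - left) / float(len(layer_sizes) - 1)
    center_y = (top + bottom) / 2.0
    positions = []
    for n, layer_size in enumerate(layer_sizes):
        # Same formula as draw_neural_net: y of the first (topmost) node.
        layer_top = v_spacing * (layer_size - 1) / 2.0 + center_y
        x = n * h_spacing + left
        positions.append([(x, layer_top - m * v_spacing) for m in range(layer_size)])
    return positions


if __name__ == "__main__":
    for layer in node_positions(.1, .9, .1, .9, [4, 7, 2]):
        print([(round(x, 3), round(y, 3)) for x, y in layer])
```

For `[4, 7, 2]` inside `.1, .9, .1, .9`, `v_spacing` is 0.8 / 7 ≈ 0.114, so the 7-node layer's outer centers land at about 0.843 and 0.157 rather than exactly on `top` and `bottom`, while the first and last layers sit exactly at `left` and `right`.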
@endolith

I don't think people are notified when you @ them on Gists, and @ljhuang2017 doesn't have any contact information. Did anyone get their code working? Can you post it as your own Gist if so?

@SebastianAvalos

Many thanks for the script!

@dvgodoy

dvgodoy commented Mar 18, 2018

I was able to run @ljhuang2017's code successfully and posted it as a new gist.
The final result looks like this:
[Image: nn_diagram]

@chieh-neu

Thank you for the code, saved me a lot of time drawing it myself.

@gsampath127

How can I add labels for the coefficients and intercepts? I need them for a demonstration.

For example, labels like a1, b1, c1, d1, etc.
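One way to approach the labeling question is to call `ax.text` for every edge and for every non-input node, reusing the gist's spacing. The sketch below assumes one coefficient per edge and one intercept per node in the hidden and output layers; `draw_coefficient_labels`, its parameters, and the default `w{layer}_{src}{dst}` / `b{layer}_{node}` names are all hypothetical. Edge labels sit a fraction `t` of the way along each edge (0.25 by default), since the midpoints of a fully connected layer tend to pile up in the middle.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to show figures interactively
import matplotlib.pyplot as plt


def draw_coefficient_labels(ax, left, right, bottom, top, layer_sizes,
                            edge_labels=None, t=0.25, fontsize=8):
    """Hypothetical helper: label edges (coefficients) and nodes (intercepts).

    Uses the same spacing as draw_neural_net. edge_labels, if given, is a flat
    list of strings in the order the loops below visit the edges.
    Returns (edge_texts, bias_texts), lists of matplotlib Text objects.
    """
    v_spacing = (top - bottom) / float(max(layer_sizes))
    h_spacing = (right - left) / float(len(layer_sizes) - 1)
    center_y = (top + bottom) / 2.0
    box = dict(boxstyle="round,pad=0.1", fc="w", ec="none")  # readable over edge lines
    edge_texts, bias_texts = [], []

    def layer_top(size):
        return v_spacing * (size - 1) / 2.0 + center_y

    k = 0
    for n, (size_a, size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        x_a, x_b = n * h_spacing + left, (n + 1) * h_spacing + left
        for m in range(size_a):
            y_a = layer_top(size_a) - m * v_spacing
            for o in range(size_b):
                y_b = layer_top(size_b) - o * v_spacing
                if edge_labels is not None:
                    label = edge_labels[k]
                else:
                    label = "w{}_{}{}".format(n + 1, m + 1, o + 1)
                # Point a fraction t of the way from source node to target node.
                x, y = x_a + t * (x_b - x_a), y_a + t * (y_b - y_a)
                edge_texts.append(ax.text(x, y, label, fontsize=fontsize, ha="center",
                                          va="center", zorder=5, bbox=box))
                k += 1
    # Intercepts: one per node in every layer after the input layer,
    # written just above the node (node radius is v_spacing / 4).
    for n, size in enumerate(layer_sizes[1:], start=1):
        x = n * h_spacing + left
        for o in range(size):
            y = layer_top(size) - o * v_spacing + 0.35 * v_spacing
            bias_texts.append(ax.text(x, y, "b{}_{}".format(n, o + 1), fontsize=fontsize,
                                      ha="center", va="bottom", zorder=5, bbox=box))
    return edge_texts, bias_texts
```

With the gist's function in the same file, call `draw_neural_net(ax, .1, .9, .1, .9, [4, 7, 2])` first and then `draw_coefficient_labels` with the same position arguments; for names like a1, b1, c1, d1, pass them in order via `edge_labels`.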

@hbaromega

@craffel Thanks for the wonderful code. However, the layers get added from top to bottom, not from left to right, which makes it difficult for me to give the nodes in the input, hidden, and output layers separate colors. Could that be done? Could annotations be added for the layers as well?
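In the original function, the outer node loop (`for n, layer_size in enumerate(layer_sizes)`) visits layers left to right; only the nodes inside each layer are added top to bottom. So per-layer colors can likely be had by indexing a color list with `n`. The sketch below is a variant with two hypothetical parameters, `layer_colors` and `layer_names`, neither of which exists in the gist.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to show figures interactively
import matplotlib.pyplot as plt


def draw_neural_net_colored(ax, left, right, bottom, top, layer_sizes,
                            layer_colors=None, layer_names=None):
    """Variant of draw_neural_net with a fill color and optional title per layer.

    layer_colors and layer_names are hypothetical parameters for this sketch;
    each is a list with one entry per layer (input first, output last).
    Returns the list of node circles, in drawing order.
    """
    n_layers = len(layer_sizes)
    if layer_colors is None:
        layer_colors = ["w"] * n_layers
    v_spacing = (top - bottom) / float(max(layer_sizes))
    h_spacing = (right - left) / float(n_layers - 1)
    center_y = (top + bottom) / 2.0
    circles = []
    # Nodes: n indexes the layer (left to right), m the node within it (top to bottom).
    for n, layer_size in enumerate(layer_sizes):
        x = n * h_spacing + left
        layer_top = v_spacing * (layer_size - 1) / 2.0 + center_y
        for m in range(layer_size):
            circle = plt.Circle((x, layer_top - m * v_spacing), v_spacing / 4.0,
                                facecolor=layer_colors[n], edgecolor="k", zorder=4)
            ax.add_artist(circle)
            circles.append(circle)
        if layer_names is not None:
            # Title each layer half a node spacing above its topmost node.
            ax.text(x, layer_top + v_spacing / 2.0, layer_names[n],
                    ha="center", va="bottom", fontsize=12)
    # Edges: same as the original function.
    for n, (size_a, size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        top_a = v_spacing * (size_a - 1) / 2.0 + center_y
        top_b = v_spacing * (size_b - 1) / 2.0 + center_y
        for m in range(size_a):
            for o in range(size_b):
                line = plt.Line2D([n * h_spacing + left, (n + 1) * h_spacing + left],
                                  [top_a - m * v_spacing, top_b - o * v_spacing], c="k")
                ax.add_artist(line)
    return circles
```

For example, `draw_neural_net_colored(ax, .1, .9, .1, .9, [4, 7, 2], layer_colors=["tab:blue", "w", "tab:red"], layer_names=["Input", "Hidden", "Output"])` fills the input nodes blue, the output nodes red, and titles all three layers.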

@ImMamey

ImMamey commented Oct 30, 2023

Hi, I wanted to use this code, but since I use Python 3, I needed to change a few things to make it work.

So I made some minor changes to @dvgodoy's gist code; it now supports Python 3 and np.array() input.

link
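The gist's own code needs only `xrange` replaced by `range` to run under Python 3. For accepting `np.array()` input, one defensive option is to convert the sizes to a plain list of Python ints before drawing. This is a sketch, and `normalize_layer_sizes` is hypothetical, not taken from either linked gist. The conversion matters if downstream code uses list operations: `[4, 7] + [2]` concatenates to `[4, 7, 2]`, whereas `np.array([4, 7]) + [2]` broadcasts to `array([6, 9])`.

```python
import numpy as np


def normalize_layer_sizes(layer_sizes):
    """Hypothetical helper: turn a list, tuple, or 1-D integer NumPy array of
    layer sizes into a plain list of Python ints, rejecting unusable input."""
    arr = np.asarray(layer_sizes)
    if arr.ndim != 1:
        raise ValueError("layer_sizes must be one-dimensional")
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError("layer sizes must be integers")
    sizes = [int(s) for s in arr]
    # draw_neural_net divides by len(layer_sizes) - 1, so one layer is not enough.
    if len(sizes) < 2:
        raise ValueError("need at least an input and an output layer")
    if min(sizes) < 1:
        raise ValueError("every layer needs at least one node")
    return sizes
```

Calling `draw_neural_net(ax, .1, .9, .1, .9, normalize_layer_sizes(sizes))` then behaves the same whether `sizes` started out as a list or an array.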
