-
-
Save nbeuchat/091c458327c39a84ba06e8686c76dfd5 to your computer and use it in GitHub Desktop.
Draw a neural network diagram with matplotlib!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Draw a neural network diagram with matplotlib.

Modified on Mon Oct 10 23:29:41 2016
@author: craffel, edited by nbeuchat (Nicolas Beuchat)
"""
import numbers

import matplotlib.pyplot as plt
def _resolve_layer_colors(color, n_layers):
    """Return a list holding one node color per layer.

    ``color`` may be a single color applied to every layer, or a sequence
    with one color per layer. A single color is either a string (``'w'``,
    ``'red'``) or an all-numeric sequence such as an RGB/RGBA triple.

    Raises:
        ValueError: if a per-layer sequence has fewer entries than layers.
    """
    # A string is always one color. Testing len() alone would split
    # multi-character names such as 'red' into 'r', 'e', 'd'.
    if isinstance(color, str):
        return [color] * n_layers
    colors = list(color)
    # An all-numeric sequence (e.g. [0.3, 0.23, 0.6]) is one RGB(A) color,
    # not one value per layer.
    if colors and all(isinstance(v, numbers.Real) for v in colors):
        return [color] * n_layers
    # A one-element sequence means "this color for every layer".
    if len(colors) == 1:
        return colors * n_layers
    if len(colors) < n_layers:
        raise ValueError(
            "color has %d entries but the network has %d layers"
            % (len(colors), n_layers))
    return colors[:n_layers]


def _layer_top(layer_size, v_spacing, bottom, top):
    """Return the y coordinate of a layer's topmost node, centering the layer vertically."""
    return v_spacing * (layer_size - 1) / 2. + (top + bottom) / 2.


def draw_neural_net(layer_sizes, ax=None, left=.1, right=.9, bottom=.1, top=.9, color='w'):
    """Draw a neural network cartoon using matplotlib.

    :usage:
        >>> fig = plt.figure(figsize=(12, 12))
        >>> draw_neural_net([4, 7, 2], fig.gca(), .1, .9, .1, .9)

    :parameters:
        - layer_sizes : list of int
            List of layer sizes, including input and output dimensionality.
            A single layer is drawn centered between ``left`` and ``right``.
        - ax : matplotlib.axes.Axes
            The axes on which to plot the cartoon (e.g. from plt.gca()). Default: gca
        - left : float
            The center of the leftmost node(s) will be placed here. Default = 0.1
        - right : float
            The center of the rightmost node(s) will be placed here. Default = 0.9
        - bottom : float
            The center of the bottommost node(s) will be placed here. Default = 0.1
        - top : float
            The center of the topmost node(s) will be placed here. Default = 0.9
        - color : str, a numeric RGB(A) sequence, or a sequence of those
            The color of the nodes, either one color for all layers or one per layer.
            Example:
                color='k' -> black neurons
                color='red' -> red neurons
                color=[0.3, 0.23, 0.6] -> one RGB color for every layer
                color=['r', 'k', 'b'] -> red input layer, black hidden layer, blue output layer
                color=['r', [0.3, 0.23, 0.6], 'g'] -> per-layer colors may mix names and RGB

    :returns:
        The axes the diagram was drawn on.

    :raises ValueError:
        If ``layer_sizes`` is empty, contains no positive size, or ``color``
        gives fewer per-layer colors than there are layers.
    """
    layer_sizes = list(layer_sizes)
    n_layers = len(layer_sizes)
    if n_layers == 0:
        raise ValueError("layer_sizes must contain at least one layer")
    largest = max(layer_sizes)
    if largest <= 0:
        raise ValueError("layer_sizes must contain at least one positive size")

    v_spacing = (top - bottom) / float(largest)
    # With a single layer there is no horizontal gap to divide, so center it
    # rather than dividing by zero.
    if n_layers > 1:
        h_spacing = (right - left) / float(n_layers - 1)
        x_positions = [left + n * h_spacing for n in range(n_layers)]
    else:
        x_positions = [(left + right) / 2.]
    layer_colors = _resolve_layer_colors(color, n_layers)

    if ax is None:
        ax = plt.gca()

    # Nodes: zorder=4 draws them on top of the connecting edges.
    for x, layer_size, node_color in zip(x_positions, layer_sizes, layer_colors):
        y_top = _layer_top(layer_size, v_spacing, bottom, top)
        for m in range(layer_size):
            circle = plt.Circle((x, y_top - m * v_spacing), v_spacing / 4.,
                                color=node_color, ec='k', zorder=4)
            ax.add_artist(circle)

    # Edges: connect every node of each layer to every node of the next layer.
    for n in range(n_layers - 1):
        size_a, size_b = layer_sizes[n], layer_sizes[n + 1]
        top_a = _layer_top(size_a, v_spacing, bottom, top)
        top_b = _layer_top(size_b, v_spacing, bottom, top)
        x_a, x_b = x_positions[n], x_positions[n + 1]
        for m in range(size_a):
            for o in range(size_b):
                line = plt.Line2D([x_a, x_b],
                                  [top_a - m * v_spacing, top_b - o * v_spacing], c='k')
                ax.add_artist(line)

    # Beautify the axes
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    return ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment