Last active
September 4, 2017 15:35
-
-
Save Kautenja/8161314e6563c8581dbd52ab3c53981f to your computer and use it in GitHub Desktop.
Solving the XOR function using a deep net in `keras`.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Solving XOR With A Deep Net\n", | |
"\n", | |
"$$\\veebar(x, y) = x'y + xy'$$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>$x$</th>\n", | |
" <th>$y$</th>\n", | |
" <th>XOR</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" $x$ $y$ XOR\n", | |
"0 0 0 0\n", | |
"1 0 1 1\n", | |
"2 1 0 1\n", | |
"3 1 1 0" | |
] | |
}, | |
"execution_count": 83, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from pandas import DataFrame\n", | |
"# generate a design matrix for XOR\n", | |
"truthtable = [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0]]\n", | |
"columns = ['$x$', '$y$', 'XOR']\n", | |
"df = DataFrame(truthtable, columns=columns)\n", | |
"df" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"X = df[['$x$', '$y$']]\n", | |
"Y = df[['XOR']]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Network Design\n", | |
"\n", | |
"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# reproducable random number seeds\n", | |
"## keras\n", | |
"from numpy.random import seed\n", | |
"seed(100)\n", | |
"## TensorFlow\n", | |
"from tensorflow import set_random_seed\n", | |
"set_random_seed(100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from keras.models import Sequential\n", | |
"from keras.layers import Dense" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 87, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# build the network in the graph above\n", | |
"model = Sequential()\n", | |
"model.name = 'XOR'\n", | |
"model.add(Dense(2, input_dim=2, activation='tanh'))\n", | |
"model.add(Dense(2, activation='tanh'))\n", | |
"model.add(Dense(1, activation='sigmoid'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 88, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"dense_22 (Dense) (None, 2) 6 \n", | |
"_________________________________________________________________\n", | |
"dense_23 (Dense) (None, 2) 6 \n", | |
"_________________________________________________________________\n", | |
"dense_24 (Dense) (None, 1) 3 \n", | |
"=================================================================\n", | |
"Total params: 15\n", | |
"Trainable params: 15\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"model.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 90, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# fit the model, 4600 epochs was the minimum needed to reach a loss of ~0.020\n", | |
"# turning off logging speeds the process up a _lot_.\n", | |
"_ = model.fit(X.values, Y.values, epochs=4600, batch_size=4, verbose=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 91, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4/4 [==============================] - 0s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[('loss', 0.0084844892844557762), ('acc', 1.0)]" | |
] | |
}, | |
"execution_count": 91, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# evaluate the score and output some metrics\n", | |
"scores = model.evaluate(X.values, Y.values)\n", | |
"[(model.metrics_names[i], scores[i]) for i in range(0, len(scores))]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 92, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0.00854679],\n", | |
" [ 0.99831319],\n", | |
" [ 0.98393595],\n", | |
" [ 0.00744388]], dtype=float32)" | |
] | |
}, | |
"execution_count": 92, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# manually check the predictions for all inputs\n", | |
"model.predict(X.values)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment