Created January 5, 2018 11:32
-
-
Save Erlemar/628aedfdcf94055b98acd28acff44840 to your computer and use it in GitHub Desktop.
Simple NumPy-only feedforward neural network (two sigmoid hidden layers) that classifies random 2-D points into 3 classes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Fix the global seed so Restart & Run All reproduces the same data (and, when run in order, the same initial weights).\n", | |
"np.random.seed(42)\n", | |
"\n", | |
"# 20 points drawn uniformly from [0, 1)^2; y will hold one-hot labels for 3 classes.\n", | |
"x = np.random.random((20, 2))\n", | |
"y = np.zeros((len(x), 3))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0.09654034, 0.05141115],\n", | |
" [ 0.04606768, 0.14275332],\n", | |
" [ 0.84532179, 0.3243917 ],\n", | |
" [ 0.94953811, 0.81377796],\n", | |
" [ 0.85867994, 0.40255299],\n", | |
" [ 0.59481875, 0.91351973],\n", | |
" [ 0.8067839 , 0.86121881],\n", | |
" [ 0.01405074, 0.79912704],\n", | |
" [ 0.06295606, 0.25968421],\n", | |
" [ 0.50689196, 0.32527316],\n", | |
" [ 0.61683513, 0.41687511],\n", | |
" [ 0.40018815, 0.87730548],\n", | |
" [ 0.52215515, 0.93456424],\n", | |
" [ 0.21109933, 0.23949889],\n", | |
" [ 0.98061506, 0.89467517],\n", | |
" [ 0.55564379, 0.26497862],\n", | |
" [ 0.1102324 , 0.74555352],\n", | |
" [ 0.67752377, 0.43244897],\n", | |
" [ 0.96567994, 0.46514039],\n", | |
" [ 0.57240303, 0.60615954]])" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x  # inputs: 20 points drawn uniformly from the unit square" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Class is set by the coordinate sum s = x1 + x2: 0 if s < 0.5, 1 if 0.5 <= s < 1, 2 if s >= 1.\n", | |
"# np.digitize with bins [0.5, 1.0] returns exactly these indices, so the sum is computed only once.\n", | |
"labels = np.digitize(np.sum(x, axis=1), [0.5, 1.0])\n", | |
"# Rebuild y here so re-running this cell always yields a clean one-hot matrix.\n", | |
"y = np.zeros((len(x), 3))\n", | |
"y[np.arange(len(x)), labels] = 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 1., 0., 0.],\n", | |
" [ 1., 0., 0.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 1., 0.],\n", | |
" [ 1., 0., 0.],\n", | |
" [ 0., 1., 0.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 1., 0., 0.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 1., 0.],\n", | |
" [ 0., 1., 0.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 0., 1.],\n", | |
" [ 0., 0., 1.]])" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y  # one-hot class labels, one row per input point" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def sigmoid(z):\n", | |
"    \"\"\"Element-wise logistic function 1 / (1 + exp(-z)).\"\"\"\n", | |
"    return 1 / (1 + np.exp(-z))\n", | |
"\n", | |
"\n", | |
"def softmax(z):\n", | |
"    \"\"\"Row-wise softmax; subtracting each row's max avoids overflow in exp without changing the result.\"\"\"\n", | |
"    e = np.exp(z - np.max(z, axis=1, keepdims=True))\n", | |
"    return e / np.sum(e, axis=1, keepdims=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Network: 2 inputs -> 16 sigmoid units -> 8 sigmoid units -> 3-way softmax (no bias terms).\n", | |
"w = np.random.random((2, 16))\n", | |
"w2 = np.random.random((16, 8))\n", | |
"w3 = np.random.random((8, 3))\n", | |
"\n", | |
"# Gradients are averaged over the batch, so the step size does not depend on the number of samples.\n", | |
"learning_rate = 0.5\n", | |
"n_samples = len(x)\n", | |
"\n", | |
"for j in range(50000):\n", | |
"    # Forward pass.\n", | |
"    a2 = sigmoid(np.dot(x, w))\n", | |
"    a3 = sigmoid(np.dot(a2, w2))\n", | |
"    a4 = softmax(np.dot(a3, w3))\n", | |
"    # Backward pass. For softmax followed by cross-entropy, d(loss)/d(logits) = a4 - y,\n", | |
"    # so the negative gradient at the output is exactly y - a4 (no extra derivative factor).\n", | |
"    a4delta = y - a4\n", | |
"    a3delta = a4delta.dot(w3.T) * (a3 * (1 - a3))\n", | |
"    a2delta = a3delta.dot(w2.T) * (a2 * (1 - a2))\n", | |
"    if j % 5000 == 0:\n", | |
"        # Clip before the log so a probability that underflowed to 0 cannot turn the loss into nan.\n", | |
"        loss = -np.mean(np.sum(y * np.log(np.clip(a4, 1e-12, 1.0)), axis=1))\n", | |
"        print(\"Error: {:.6f}  Loss: {:.6f}\".format(np.mean(np.abs(y - a4)), loss))\n", | |
"    # Deltas already point downhill, hence += (gradient descent on the cross-entropy loss).\n", | |
"    w3 += a3.T.dot(a4delta) * (learning_rate / n_samples)\n", | |
"    w2 += a2.T.dot(a3delta) * (learning_rate / n_samples)\n", | |
"    w += x.T.dot(a2delta) * (learning_rate / n_samples)\n", | |
"print(a4)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda root]", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment