Skip to content

Instantly share code, notes, and snippets.

@Erlemar
Created January 5, 2018 11:32
Show Gist options
  • Save Erlemar/628aedfdcf94055b98acd28acff44840 to your computer and use it in GitHub Desktop.
Save Erlemar/628aedfdcf94055b98acd28acff44840 to your computer and use it in GitHub Desktop.
Simple feedforward net with 3 classes.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"x = np.random.random((20, 2))\n",
"y = np.zeros((20, 3))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.09654034, 0.05141115],\n",
" [ 0.04606768, 0.14275332],\n",
" [ 0.84532179, 0.3243917 ],\n",
" [ 0.94953811, 0.81377796],\n",
" [ 0.85867994, 0.40255299],\n",
" [ 0.59481875, 0.91351973],\n",
" [ 0.8067839 , 0.86121881],\n",
" [ 0.01405074, 0.79912704],\n",
" [ 0.06295606, 0.25968421],\n",
" [ 0.50689196, 0.32527316],\n",
" [ 0.61683513, 0.41687511],\n",
" [ 0.40018815, 0.87730548],\n",
" [ 0.52215515, 0.93456424],\n",
" [ 0.21109933, 0.23949889],\n",
" [ 0.98061506, 0.89467517],\n",
" [ 0.55564379, 0.26497862],\n",
" [ 0.1102324 , 0.74555352],\n",
" [ 0.67752377, 0.43244897],\n",
" [ 0.96567994, 0.46514039],\n",
" [ 0.57240303, 0.60615954]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# The class of each sample is decided by the sum of its two features.\n",
"row_total = x.sum(axis=1)\n",
"is_low = row_total < 0.5\n",
"is_mid = np.logical_and(row_total >= 0.5, row_total < 1)\n",
"is_high = row_total >= 1\n",
"\n",
"# One-hot encode: column 0 = sum < 0.5, column 1 = 0.5 <= sum < 1, column 2 = sum >= 1.\n",
"y[is_low, 0] = 1\n",
"y[is_mid, 1] = 1\n",
"y[is_high, 2] = 1"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1., 0., 0.],\n",
" [ 1., 0., 0.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 1., 0.],\n",
" [ 1., 0., 0.],\n",
" [ 0., 1., 0.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 0., 1.],\n",
" [ 1., 0., 0.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 1., 0.],\n",
" [ 0., 1., 0.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 0., 1.],\n",
" [ 0., 0., 1.]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Error:0.444505472619\n",
"Error:0.00616825647943\n",
"Error:0.00412845147618\n",
"Error:0.00333792186315\n",
"Error:0.00288331427728\n",
"Error:0.0025775546503\n",
"Error:0.0023533986912\n",
"Error:0.00217978768693\n",
"Error:0.00204008904013\n",
"Error:0.00192446969304\n",
"[[ 9.98605229e-01 1.39476876e-03 2.73711302e-09]\n",
" [ 9.98594653e-01 1.40534421e-03 2.75635821e-09]\n",
" [ 4.09460048e-13 1.85230329e-04 9.99814770e-01]\n",
" [ 1.10368135e-13 1.98437990e-05 9.99980156e-01]\n",
" [ 1.46035996e-13 2.99785919e-05 9.99970021e-01]\n",
" [ 1.12261587e-13 2.02938728e-05 9.99979706e-01]\n",
" [ 1.10547429e-13 1.98831990e-05 9.99980117e-01]\n",
" [ 5.03180608e-03 9.94948361e-01 1.98330254e-05]\n",
" [ 9.98224838e-01 1.77515945e-03 3.01689258e-09]\n",
" [ 6.69091814e-04 9.86878971e-01 1.24519376e-02]\n",
" [ 3.49494280e-12 1.07574726e-02 9.89242527e-01]\n",
" [ 1.47093099e-13 3.23512304e-05 9.99967649e-01]\n",
" [ 1.14482110e-13 2.09143093e-05 9.99979086e-01]\n",
" [ 9.88585961e-01 1.14140341e-02 4.43430083e-09]\n",
" [ 1.10199201e-13 1.98283563e-05 9.99980172e-01]\n",
" [ 6.63185716e-03 9.93194072e-01 1.74070848e-04]\n",
" [ 4.59075665e-04 9.97490908e-01 2.05001599e-03]\n",
" [ 3.74500662e-13 1.62437353e-04 9.99837563e-01]\n",
" [ 1.15843485e-13 2.08732467e-05 9.99979127e-01]\n",
" [ 1.56929616e-13 3.50518507e-05 9.99964948e-01]]\n"
]
}
],
"source": [
"def sigmoid(z):\n",
"    \"\"\"Element-wise logistic function 1 / (1 + exp(-z)).\"\"\"\n",
"    return 1 / (1 + np.exp(-z))\n",
"\n",
"\n",
"# Training settings, gathered in one place so they are easy to tune.\n",
"weight_seed = 1        # seed for the weight initialisation only\n",
"learning_rate = 0.1\n",
"n_iterations = 50000\n",
"report_every = 5000    # print the mean absolute error every this many iterations\n",
"\n",
"# A private generator makes the initial weights reproducible without\n",
"# touching NumPy's global random state.\n",
"rng = np.random.RandomState(weight_seed)\n",
"w = rng.random_sample((2, 16))    # 2 inputs   -> 16 hidden units\n",
"w2 = rng.random_sample((16, 8))   # 16 hidden  -> 8 hidden units\n",
"w3 = rng.random_sample((8, 3))    # 8 hidden   -> 3 class outputs\n",
"\n",
"for j in range(n_iterations):\n",
"    # Forward pass: three sigmoid layers, then each output row is rescaled to sum to 1.\n",
"    a2 = sigmoid(np.dot(x, w))\n",
"    a3 = sigmoid(np.dot(a2, w2))\n",
"    a4 = sigmoid(np.dot(a3, w3))\n",
"    a4 = a4 / np.sum(a4, axis=1, keepdims=True)\n",
"\n",
"    # Backward pass: push the output error back through each layer.\n",
"    # NOTE(review): a4 * (1 - a4) is the sigmoid-derivative form, but it is applied to the\n",
"    # row-normalised a4 rather than the raw sigmoid output, so a4delta is not the exact\n",
"    # gradient of the forward pass above. The update rule is deliberately left unchanged;\n",
"    # consider softmax + cross-entropy (a4delta = y - a4) and re-tune learning_rate.\n",
"    a4delta = (y - a4) * (a4 * (1 - a4))\n",
"    a3delta = a4delta.dot(w3.T) * (a3 * (1 - a3))\n",
"    a2delta = a3delta.dot(w2.T) * (a2 * (1 - a2))\n",
"\n",
"    if (j % report_every) == 0:\n",
"        print(\"Error:\" + str(np.mean(np.abs(y - a4))))\n",
"\n",
"    # Gradients are summed over all samples (not averaged) before scaling.\n",
"    w3 += a3.T.dot(a4delta) * learning_rate\n",
"    w2 += a2.T.dot(a3delta) * learning_rate\n",
"    w += x.T.dot(a2delta) * learning_rate\n",
"\n",
"# Predictions from the last forward pass (computed before the final weight update).\n",
"print(a4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment