Last active
August 31, 2017 01:51
-
-
Save izmailovpavel/7732a2ef6634aaade16a5a24b95b1c66 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from matplotlib import pyplot as plt\n", | |
"\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"d = 2\n", | |
"seq_len = 5\n", | |
"n_examples = 100\n", | |
"x = np.random.normal(size=(n_examples, seq_len, d))\n", | |
"w = np.random.normal(size=(d, 1))\n", | |
"y = np.sum(x.dot(w), axis=1)\n", | |
"\n", | |
"n_te = 100\n", | |
"x_te = np.random.normal(size=(n_te, d))\n", | |
"y_te = x_te.dot(w)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from keras.models import Model\n", | |
"from keras.layers import Dense, Input, Concatenate\n", | |
"from keras.layers.merge import add\n", | |
"from keras.optimizers import SGD" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 111, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"inputs = []\n", | |
"for input_no in range(seq_len):\n", | |
" inputs.append(Input(shape=(d,)))\n", | |
" \n", | |
"linear_layer = Dense(1)\n", | |
"\n", | |
"rs = [linear_layer(input_) for input_ in inputs]\n", | |
"\n", | |
"output = add(rs)#, mode='sum', output_shape=(1,))\n", | |
"net = Model(inputs, output)\n", | |
"net_one_input = Model(inputs[0], rs[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 112, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"____________________________________________________________________________________________________\n", | |
"Layer (type) Output Shape Param # Connected to \n", | |
"====================================================================================================\n", | |
"input_86 (InputLayer) (None, 2) 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"input_87 (InputLayer) (None, 2) 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"input_88 (InputLayer) (None, 2) 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"input_89 (InputLayer) (None, 2) 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"input_90 (InputLayer) (None, 2) 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"dense_25 (Dense) (None, 1) 3 input_86[0][0] \n", | |
" input_87[0][0] \n", | |
" input_88[0][0] \n", | |
" input_89[0][0] \n", | |
" input_90[0][0] \n", | |
"____________________________________________________________________________________________________\n", | |
"add_11 (Add) (None, 1) 0 dense_25[0][0] \n", | |
" dense_25[1][0] \n", | |
" dense_25[2][0] \n", | |
" dense_25[3][0] \n", | |
" dense_25[4][0] \n", | |
"====================================================================================================\n", | |
"Total params: 3\n", | |
"Trainable params: 3\n", | |
"Non-trainable params: 0\n", | |
"____________________________________________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"net.compile(loss=\"mse\", optimizer=SGD(lr=0.01))\n", | |
"net.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 113, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/20\n", | |
"100/100 [==============================] - 0s - loss: 3.8750 \n", | |
"Epoch 2/20\n", | |
"100/100 [==============================] - 0s - loss: 1.9068 \n", | |
"Epoch 3/20\n", | |
"100/100 [==============================] - 0s - loss: 0.9260 \n", | |
"Epoch 4/20\n", | |
"100/100 [==============================] - 0s - loss: 0.4834 \n", | |
"Epoch 5/20\n", | |
"100/100 [==============================] - 0s - loss: 0.2492 \n", | |
"Epoch 6/20\n", | |
"100/100 [==============================] - 0s - loss: 0.0905 \n", | |
"Epoch 7/20\n", | |
"100/100 [==============================] - 0s - loss: 0.0361 \n", | |
"Epoch 8/20\n", | |
"100/100 [==============================] - 0s - loss: 0.0173 \n", | |
"Epoch 9/20\n", | |
"100/100 [==============================] - 0s - loss: 0.0083 \n", | |
"Epoch 10/20\n", | |
"100/100 [==============================] - 0s - loss: 0.0042 \n", | |
"Epoch 11/20\n", | |
"100/100 [==============================] - 0s - loss: 0.0019 \n", | |
"Epoch 12/20\n", | |
"100/100 [==============================] - 0s - loss: 0.0010 \n", | |
"Epoch 13/20\n", | |
"100/100 [==============================] - 0s - loss: 5.4117e-04 \n", | |
"Epoch 14/20\n", | |
"100/100 [==============================] - 0s - loss: 3.0929e-04 \n", | |
"Epoch 15/20\n", | |
"100/100 [==============================] - 0s - loss: 1.6779e-04 \n", | |
"Epoch 16/20\n", | |
"100/100 [==============================] - 0s - loss: 8.8566e-05 \n", | |
"Epoch 17/20\n", | |
"100/100 [==============================] - 0s - loss: 3.1254e-05 \n", | |
"Epoch 18/20\n", | |
"100/100 [==============================] - 0s - loss: 1.4281e-05 \n", | |
"Epoch 19/20\n", | |
"100/100 [==============================] - 0s - loss: 6.8632e-06 \n", | |
"Epoch 20/20\n", | |
"100/100 [==============================] - 0s - loss: 3.3126e-06 \n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<keras.callbacks.History at 0x13c10c978>" | |
] | |
}, | |
"execution_count": 113, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"net.fit(x=[x[:, i, :] for i in range(seq_len)], y=y, epochs=20)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 115, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.metrics import r2_score" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 116, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.99999957519803795" | |
] | |
}, | |
"execution_count": 116, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"r2_score(y_te, net_one_input.predict(x_te))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment