Created
July 16, 2017 22:22
-
-
Save tomtung/c030219cdb731ad67be00cb049b5dc22 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"import keras\n", | |
"import numpy as np\n", | |
"import random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Maximum number of decimal digits in each addend.\n", | |
"N_DIGITS = 5\n", | |
"# Longest question: two N_DIGITS-digit addends joined by a single '+'.\n", | |
"INPUT_LEN = 2 * N_DIGITS + 1\n", | |
"# The sum of two N_DIGITS-digit numbers has at most N_DIGITS + 1 digits.\n", | |
"OUTPUT_LEN = 1 + N_DIGITS\n", | |
"\n", | |
"# Vocabulary for one-hot encoding; index 0 is the space used as padding.\n", | |
"CHARS = list(' 1234567890+')\n", | |
"CHAR_TO_INDEX = dict(zip(CHARS, range(len(CHARS))))\n", | |
"\n", | |
"# Number of unique addend pairs generated for each split.\n", | |
"TRAIN_DATA_SIZE = 600000\n", | |
"TEST_DATA_SIZE = 100000" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Data Generation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"A random number: 3\n" | |
] | |
} | |
], | |
"source": [ | |
"def generate_random_number():\n", | |
"    \"\"\"Draw a random non-negative integer with at most N_DIGITS digits.\n", | |
"\n", | |
"    A digit count is picked uniformly from 1..N_DIGITS first, then a value is\n", | |
"    drawn uniformly below 10 ** digit_count, so short numbers come up far more\n", | |
"    often than they would under a single uniform draw over the whole range.\n", | |
"    \"\"\"\n", | |
"    n_digits = random.randint(1, N_DIGITS)\n", | |
"    upper_bound = 10 ** n_digits\n", | |
"    return random.randrange(0, upper_bound)\n", | |
"\n", | |
"print('A random number: {}'.format(generate_random_number()))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(8, 27988)\n" | |
] | |
} | |
], | |
"source": [ | |
"def generate_addend_pair():\n", | |
"    \"\"\"Return two independently drawn random addends as an (x, y) tuple.\"\"\"\n", | |
"    first_addend = generate_random_number()\n", | |
"    second_addend = generate_random_number()\n", | |
"    return first_addend, second_addend\n", | |
"\n", | |
"print(generate_addend_pair())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"An example: ('12+345 ', '357 ')\n" | |
] | |
} | |
], | |
"source": [ | |
"def generate_str_example(x, y):\n", | |
"    \"\"\"Build the space-padded question and answer strings for x + y.\n", | |
"\n", | |
"    Returns a tuple (input_str, output_str): 'x+y' padded on the right with\n", | |
"    spaces up to INPUT_LEN characters, and the decimal sum padded on the right\n", | |
"    up to OUTPUT_LEN characters. Strings already longer than that are returned\n", | |
"    without padding or truncation.\n", | |
"    \"\"\"\n", | |
"    question = '{}+{}'.format(x, y).ljust(INPUT_LEN)\n", | |
"    answer = str(x + y).ljust(OUTPUT_LEN)\n", | |
"    return question, answer\n", | |
"\n", | |
"print('An example: {}'.format(generate_str_example(12, 345)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[(11, 12), (6, 12)]\n" | |
] | |
} | |
], | |
"source": [ | |
"def _one_hot_encode(text, length):\n", | |
"    \"\"\"One-hot encode text as a (length, len(CHARS)) array, one row per character.\n", | |
"\n", | |
"    Raises ValueError when text does not have exactly `length` characters,\n", | |
"    and KeyError when it contains a character that is not in CHARS.\n", | |
"    \"\"\"\n", | |
"    if len(text) != length:\n", | |
"        raise ValueError(\n", | |
"            'expected {} characters, got {!r}'.format(length, text))\n", | |
"    encoded = np.zeros((length, len(CHARS)))\n", | |
"    for position, char in enumerate(text):\n", | |
"        encoded[position, CHAR_TO_INDEX[char]] = 1\n", | |
"    return encoded\n", | |
"\n", | |
"\n", | |
"def generate_example(x, y):\n", | |
"    \"\"\"Return the one-hot encoded (input, output) arrays for the problem x + y.\n", | |
"\n", | |
"    The input array has shape (INPUT_LEN, len(CHARS)) and the output array has\n", | |
"    shape (OUTPUT_LEN, len(CHARS)); both encode the padded strings produced by\n", | |
"    generate_str_example.\n", | |
"    \"\"\"\n", | |
"    input_str, output_str = generate_str_example(x, y)\n", | |
"    input_ = _one_hot_encode(input_str, INPUT_LEN)\n", | |
"    output = _one_hot_encode(output_str, OUTPUT_LEN)\n", | |
"    return input_, output\n", | |
"\n", | |
"print([array.shape for array in generate_example(12, 345)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"training_x shape: (600000, 11, 12)\n", | |
"training_y shape: (600000, 6, 12)\n", | |
"testing_x shape: (100000, 11, 12)\n", | |
"testing_y shape: (100000, 6, 12)\n" | |
] | |
} | |
], | |
"source": [ | |
"def generate_examples(n_train, n_test):\n", | |
"    \"\"\"Generate one-hot encoded train/test sets of unique addition problems.\n", | |
"\n", | |
"    Returns (train_inputs, train_outputs, test_inputs, test_outputs). No\n", | |
"    ordered (x, y) pair is generated twice, so the two splits never share a\n", | |
"    problem.\n", | |
"\n", | |
"    Raises ValueError if a size is negative or if more unique pairs are\n", | |
"    requested than generate_addend_pair() can produce.\n", | |
"    \"\"\"\n", | |
"    if n_train < 0 or n_test < 0:\n", | |
"        raise ValueError('n_train and n_test must be non-negative')\n", | |
"    n_examples = n_train + n_test\n", | |
"\n", | |
"    # Each addend lies in [0, 10 ** N_DIGITS), so only this many distinct\n", | |
"    # ordered pairs exist; asking for more would make the loop below spin forever.\n", | |
"    n_possible_pairs = (10 ** N_DIGITS) ** 2\n", | |
"    if n_examples > n_possible_pairs:\n", | |
"        raise ValueError(\n", | |
"            'requested {} unique addend pairs, but only {} exist for N_DIGITS={}'.format(\n", | |
"                n_examples, n_possible_pairs, N_DIGITS))\n", | |
"\n", | |
"    addend_pairs = set()\n", | |
"    while len(addend_pairs) < n_examples:\n", | |
"        addend_pairs.add(generate_addend_pair())\n", | |
"\n", | |
"    # Fill preallocated arrays instead of stacking per-example arrays: this\n", | |
"    # keeps peak memory down and gives empty splits a correctly shaped array.\n", | |
"    inputs = np.zeros((n_examples, INPUT_LEN, len(CHARS)))\n", | |
"    outputs = np.zeros((n_examples, OUTPUT_LEN, len(CHARS)))\n", | |
"    for row, (x, y) in enumerate(addend_pairs):\n", | |
"        inputs[row], outputs[row] = generate_example(x, y)\n", | |
"\n", | |
"    return inputs[:n_train], outputs[:n_train], inputs[n_train:], outputs[n_train:]\n", | |
"\n", | |
"training_x, training_y, testing_x, testing_y = generate_examples(TRAIN_DATA_SIZE, TEST_DATA_SIZE)\n", | |
"print('training_x shape:', training_x.shape)\n", | |
"print('training_y shape:', training_y.shape)\n", | |
"print('testing_x shape:', testing_x.shape)\n", | |
"print('testing_y shape:', testing_y.shape)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Number of units in each LSTM layer.\n", | |
"HIDDEN_SIZE = 128\n", | |
"# Mini-batch size passed to model.fit.\n", | |
"BATCH_SIZE = 128\n", | |
"# Upper bound on training epochs; early stopping normally ends training sooner.\n", | |
"MAX_N_EPOCHS = 1000\n", | |
"# Backward-compatible alias for the original misspelled name, which the\n", | |
"# training cell still references.\n", | |
"MAX_N_EPOCS = MAX_N_EPOCHS" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Encoder-decoder: a bidirectional LSTM condenses the input characters into a\n", | |
"# single vector, which is repeated once per output position and decoded by a\n", | |
"# second LSTM; a per-position softmax then scores every character in CHARS.\n", | |
"# Layers are taken from the public keras.layers namespace rather than the\n", | |
"# internal wrappers/recurrent/core submodules, which are not part of the\n", | |
"# public API and are missing from newer Keras releases.\n", | |
"model = keras.models.Sequential([\n", | |
"    keras.layers.Bidirectional(\n", | |
"        keras.layers.LSTM(HIDDEN_SIZE),\n", | |
"        input_shape=(INPUT_LEN, len(CHARS))\n", | |
"    ),\n", | |
"    keras.layers.RepeatVector(OUTPUT_LEN),\n", | |
"    keras.layers.LSTM(HIDDEN_SIZE, return_sequences=True),\n", | |
"    keras.layers.TimeDistributed(\n", | |
"        keras.layers.Dense(len(CHARS), activation='softmax')\n", | |
"    ),\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"bidirectional_1 (Bidirection (None, 256) 144384 \n", | |
"_________________________________________________________________\n", | |
"repeat_vector_1 (RepeatVecto (None, 6, 256) 0 \n", | |
"_________________________________________________________________\n", | |
"lstm_2 (LSTM) (None, 6, 128) 197120 \n", | |
"_________________________________________________________________\n", | |
"time_distributed_1 (TimeDist (None, 6, 12) 1548 \n", | |
"=================================================================\n", | |
"Total params: 343,052\n", | |
"Trainable params: 343,052\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"# Print each layer's output shape and parameter count as a sanity check.\n", | |
"model.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Each output position is a classification over CHARS, hence the categorical\n", | |
"# cross-entropy loss; accuracy is tracked alongside it.\n", | |
"model.compile(\n", | |
"    loss='categorical_crossentropy',\n", | |
"    optimizer='adam',\n", | |
"    metrics=['accuracy'],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train on 480000 samples, validate on 120000 samples\n", | |
"Epoch 1/1000\n", | |
"157s - loss: 1.1147 - acc: 0.5823 - val_loss: 0.5931 - val_acc: 0.8039\n", | |
"Epoch 2/1000\n", | |
"146s - loss: 0.4039 - acc: 0.8568 - val_loss: 0.3069 - val_acc: 0.8849\n", | |
"Epoch 3/1000\n", | |
"142s - loss: 0.1948 - acc: 0.9316 - val_loss: 0.1232 - val_acc: 0.9608\n", | |
"Epoch 4/1000\n", | |
"142s - loss: 0.0904 - acc: 0.9708 - val_loss: 0.0693 - val_acc: 0.9774\n", | |
"Epoch 5/1000\n", | |
"142s - loss: 0.0559 - acc: 0.9816 - val_loss: 0.0418 - val_acc: 0.9864\n", | |
"Epoch 6/1000\n", | |
"142s - loss: 0.0373 - acc: 0.9880 - val_loss: 0.0304 - val_acc: 0.9900\n", | |
"Epoch 7/1000\n", | |
"142s - loss: 0.0280 - acc: 0.9911 - val_loss: 0.0195 - val_acc: 0.9941\n", | |
"Epoch 8/1000\n", | |
"142s - loss: 0.0210 - acc: 0.9935 - val_loss: 0.0241 - val_acc: 0.9919\n", | |
"Epoch 9/1000\n", | |
"142s - loss: 0.0181 - acc: 0.9946 - val_loss: 0.0103 - val_acc: 0.9970\n", | |
"Epoch 10/1000\n", | |
"141s - loss: 0.0137 - acc: 0.9960 - val_loss: 0.0108 - val_acc: 0.9967\n", | |
"Epoch 11/1000\n", | |
"141s - loss: 0.0143 - acc: 0.9959 - val_loss: 0.0148 - val_acc: 0.9958\n", | |
"Epoch 12/1000\n", | |
"141s - loss: 0.0114 - acc: 0.9968 - val_loss: 0.0046 - val_acc: 0.9988\n", | |
"Epoch 13/1000\n", | |
"141s - loss: 0.0110 - acc: 0.9968 - val_loss: 0.0067 - val_acc: 0.9980\n", | |
"Epoch 14/1000\n", | |
"141s - loss: 0.0076 - acc: 0.9978 - val_loss: 0.0048 - val_acc: 0.9986\n", | |
"Epoch 15/1000\n", | |
"141s - loss: 0.0093 - acc: 0.9975 - val_loss: 0.0081 - val_acc: 0.9975\n", | |
"Epoch 16/1000\n", | |
"141s - loss: 0.0073 - acc: 0.9979 - val_loss: 0.0080 - val_acc: 0.9975\n", | |
"Epoch 17/1000\n", | |
"141s - loss: 0.0059 - acc: 0.9983 - val_loss: 0.0043 - val_acc: 0.9987\n", | |
"Epoch 18/1000\n", | |
"141s - loss: 0.0068 - acc: 0.9981 - val_loss: 0.0046 - val_acc: 0.9986\n", | |
"Epoch 19/1000\n", | |
"141s - loss: 0.0058 - acc: 0.9984 - val_loss: 0.0044 - val_acc: 0.9987\n", | |
"Epoch 20/1000\n", | |
"141s - loss: 0.0064 - acc: 0.9982 - val_loss: 0.0094 - val_acc: 0.9972\n", | |
"Epoch 21/1000\n", | |
"141s - loss: 0.0053 - acc: 0.9985 - val_loss: 0.0039 - val_acc: 0.9989\n", | |
"Epoch 22/1000\n", | |
"141s - loss: 0.0041 - acc: 0.9989 - val_loss: 0.0042 - val_acc: 0.9987\n", | |
"Epoch 23/1000\n", | |
"141s - loss: 0.0050 - acc: 0.9986 - val_loss: 0.0172 - val_acc: 0.9949\n", | |
"Epoch 24/1000\n", | |
"141s - loss: 0.0038 - acc: 0.9989 - val_loss: 0.0033 - val_acc: 0.9990\n", | |
"Epoch 25/1000\n", | |
"141s - loss: 0.0051 - acc: 0.9987 - val_loss: 0.0020 - val_acc: 0.9995\n", | |
"Epoch 26/1000\n", | |
"142s - loss: 0.0042 - acc: 0.9988 - val_loss: 0.0023 - val_acc: 0.9994\n", | |
"Epoch 27/1000\n", | |
"141s - loss: 0.0044 - acc: 0.9988 - val_loss: 0.0018 - val_acc: 0.9995\n", | |
"Epoch 28/1000\n", | |
"141s - loss: 0.0032 - acc: 0.9991 - val_loss: 0.0029 - val_acc: 0.9992\n", | |
"Epoch 29/1000\n", | |
"144s - loss: 0.0042 - acc: 0.9988 - val_loss: 0.0085 - val_acc: 0.9974\n", | |
"Epoch 30/1000\n", | |
"144s - loss: 0.0033 - acc: 0.9991 - val_loss: 0.0019 - val_acc: 0.9995\n", | |
"Epoch 31/1000\n", | |
"157s - loss: 0.0039 - acc: 0.9990 - val_loss: 0.0014 - val_acc: 0.9997\n", | |
"Epoch 32/1000\n", | |
"154s - loss: 0.0028 - acc: 0.9992 - val_loss: 0.0033 - val_acc: 0.9991\n", | |
"Epoch 33/1000\n", | |
"150s - loss: 0.0031 - acc: 0.9992 - val_loss: 0.0013 - val_acc: 0.9997\n", | |
"Epoch 34/1000\n", | |
"154s - loss: 0.0028 - acc: 0.9992 - val_loss: 0.0024 - val_acc: 0.9993\n", | |
"Epoch 35/1000\n", | |
"152s - loss: 0.0032 - acc: 0.9992 - val_loss: 0.0038 - val_acc: 0.9988\n", | |
"Epoch 36/1000\n", | |
"157s - loss: 0.0026 - acc: 0.9993 - val_loss: 0.0037 - val_acc: 0.9989\n", | |
"Epoch 37/1000\n", | |
"152s - loss: 0.0024 - acc: 0.9993 - val_loss: 0.0016 - val_acc: 0.9996\n", | |
"Epoch 38/1000\n", | |
"153s - loss: 0.0031 - acc: 0.9992 - val_loss: 0.0024 - val_acc: 0.9992\n", | |
"Epoch 39/1000\n", | |
"155s - loss: 0.0025 - acc: 0.9993 - val_loss: 0.0031 - val_acc: 0.9990\n", | |
"Epoch 00038: early stopping\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<keras.callbacks.History at 0x1a2ad1d6e48>" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Halt training once the monitored validation metric has gone 5 epochs\n", | |
"# without improving, instead of running all MAX_N_EPOCS epochs.\n", | |
"early_stopping = keras.callbacks.EarlyStopping(patience=5, verbose=2)\n", | |
"\n", | |
"# validation_split holds out 20% of the training arrays for validation.\n", | |
"model.fit(\n", | |
"    training_x,\n", | |
"    training_y,\n", | |
"    batch_size=BATCH_SIZE,\n", | |
"    epochs=MAX_N_EPOCS,\n", | |
"    validation_split=.2,\n", | |
"    callbacks=[early_stopping],\n", | |
"    verbose=2,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"100000/100000 [==============================] - 32s \n", | |
"\n", | |
"Test loss: 0.0030354137425270163\n", | |
"Test acc: 0.9990433450508117\n" | |
] | |
} | |
], | |
"source": [ | |
"# Score the trained model on the held-out test split.\n", | |
"metrics_vals = model.evaluate(testing_x, testing_y)\n", | |
"\n", | |
"# The empty print separates the report from Keras' progress bar.\n", | |
"print('')\n", | |
"report_template = 'Test {}: {}'\n", | |
"for metric_name, metric_val in zip(model.metrics_names, metrics_vals):\n", | |
"    print(report_template.format(metric_name, metric_val))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Model In Action" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def neural_addition(x, y):\n", | |
"    \"\"\"Ask the trained model for x + y and return its answer as a string.\n", | |
"\n", | |
"    The returned string has one character per output position, chosen as the\n", | |
"    highest-scoring entry of CHARS, so it may contain padding spaces.\n", | |
"    \"\"\"\n", | |
"    encoded_question, _ = generate_example(x, y)\n", | |
"    # The model works on batches, so wrap the single example in a batch of one.\n", | |
"    single_batch = np.array([encoded_question])\n", | |
"    char_scores = model.predict_on_batch(single_batch)[0]\n", | |
"    best_indices = np.argmax(char_scores, axis=1)\n", | |
"    return ''.join(CHARS[position] for position in best_indices)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"163 + 0 = \"163 \" (correct)\n", | |
"96 + 453 = \"549 \" (correct)\n", | |
"69 + 557 = \"626 \" (correct)\n", | |
"7721 + 98 = \"7819 \" (correct)\n", | |
"5112 + 79646 = \"84758 \" (correct)\n", | |
"493 + 43044 = \"43537 \" (correct)\n", | |
"51 + 489 = \"540 \" (correct)\n", | |
"84628 + 3457 = \"88085 \" (correct)\n", | |
"1 + 2236 = \"2237 \" (correct)\n", | |
"0 + 4622 = \"4622 \" (correct)\n", | |
"67 + 0 = \"67 \" (correct)\n", | |
"90642 + 68 = \"90710 \" (correct)\n", | |
"6 + 6 = \"12 \" (correct)\n", | |
"38973 + 23 = \"38996 \" (correct)\n", | |
"4 + 5945 = \"5949 \" (correct)\n", | |
"155 + 321 = \"476 \" (correct)\n", | |
"4987 + 2805 = \"7792 \" (correct)\n", | |
"70001 + 8 = \"70009 \" (correct)\n", | |
"1085 + 36 = \"1121 \" (correct)\n", | |
"13 + 2969 = \"2982 \" (correct)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Spot-check the model on 20 freshly drawn problems.\n", | |
"for _ in range(20):\n", | |
"    x, y = generate_addend_pair()\n", | |
"    expected = x + y\n", | |
"    result = neural_addition(x, y)\n", | |
"    # The prediction carries padding spaces, so strip it before comparing.\n", | |
"    if result.strip() != str(expected):\n", | |
"        print('{} + {} = \"{}\" (incorrect, should be {})'.format(x, y, result, expected))\n", | |
"        continue\n", | |
"    print('{} + {} = \"{}\" (correct)'.format(x, y, result))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment