Created
August 10, 2021 10:47
-
-
Save 79man/4bd1bd1afa0589b26445954f1514bddc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 142, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from __future__ import print_function\n", | |
"from keras.models import Sequential\n", | |
"from keras import layers\n", | |
"import numpy as np\n", | |
"from six.moves import range\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"class CharacterTable(object):\n", | |
" \"\"\"Given a set of characters:\n", | |
" + Encode them to a one-hot integer representation\n", | |
" + Decode the one-hot or integer representation to their character output\n", | |
" + Decode a vector of probabilities to their character output\n", | |
" \"\"\"\n", | |
" def __init__(self, chars):\n", | |
" \"\"\"Initialize character table.\n", | |
"\n", | |
" # Arguments\n", | |
" chars: Characters that can appear in the input.\n", | |
" \"\"\"\n", | |
" self.chars = sorted(set(chars))\n", | |
" self.char_indices = dict((c, i) for i, c in enumerate(self.chars))\n", | |
" self.indices_char = dict((i, c) for i, c in enumerate(self.chars))\n", | |
"\n", | |
" def encode(self, C, num_rows):\n", | |
" \"\"\"One-hot encode given string C.\n", | |
"\n", | |
" # Arguments\n", | |
" C: string, to be encoded.\n", | |
" num_rows: Number of rows in the returned one-hot encoding. This is\n", | |
" used to keep the # of rows for each data the same.\n", | |
" \"\"\"\n", | |
" x = np.zeros((num_rows, len(self.chars)))\n", | |
" for i, c in enumerate(C):\n", | |
" x[i, self.char_indices[c]] = 1\n", | |
" return x\n", | |
"\n", | |
" def decode(self, x, calc_argmax=True):\n", | |
" \"\"\"Decode the given vector or 2D array to their character output.\n", | |
"\n", | |
" # Arguments\n", | |
" x: A vector or a 2D array of probabilities or one-hot representations;\n", | |
" or a vector of character indices (used with `calc_argmax=False`).\n", | |
" calc_argmax: Whether to find the character index with maximum\n", | |
" probability, defaults to `True`.\n", | |
" \"\"\"\n", | |
" if calc_argmax:\n", | |
" x = x.argmax(axis=-1)\n", | |
" return ''.join(self.indices_char[x] for x in x)\n", | |
"\n", | |
"class colors:\n", | |
" ok = '\\033[92m'\n", | |
" fail = '\\033[91m'\n", | |
" close = '\\033[0m'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 143, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Step - 1 : Setiing up the Parameters\n", | |
"# Parameters for the model and dataset.\n", | |
"TRAINING_SIZE = 50000\n", | |
"DIGITS = 3\n", | |
"REVERSE = True\n", | |
"\n", | |
"# Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of\n", | |
"# int is DIGITS.\n", | |
"MAXLEN = DIGITS + 1 + DIGITS\n", | |
"\n", | |
"# All the numbers, plus sign and space for padding.\n", | |
"chars = '0123456789+ '\n", | |
"ctable = CharacterTable(chars)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 144, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Generating data...\n", | |
"question : reversed : training : actual\n", | |
"'15+91' : ' 19+51' : '106 ' : '106'\n", | |
"'556+7' : ' 7+655' : '563 ' : '563'\n", | |
"'91+11' : ' 11+19' : '102 ' : '102'\n", | |
"'18+776' : ' 677+81' : '794 ' : '794'\n", | |
"'32+357' : ' 753+23' : '389 ' : '389'\n", | |
"'242+7' : ' 7+242' : '249 ' : '249'\n", | |
"'508+67' : ' 76+805' : '575 ' : '575'\n", | |
"'990+77' : ' 77+099' : '1067' : '1067'\n", | |
"'2+187' : ' 781+2' : '189 ' : '189'\n", | |
"'96+59' : ' 95+69' : '155 ' : '155'\n", | |
"'904+4' : ' 4+409' : '908 ' : '908'\n", | |
"'51+871' : ' 178+15' : '922 ' : '922'\n", | |
"'669+8' : ' 8+966' : '677 ' : '677'\n", | |
"'88+829' : ' 928+88' : '917 ' : '917'\n", | |
"'733+883' : '388+337' : '1616' : '1616'\n", | |
"'896+793' : '397+698' : '1689' : '1689'\n", | |
"'132+4' : ' 4+231' : '136 ' : '136'\n", | |
"'85+380' : ' 083+58' : '465 ' : '465'\n", | |
"'139+975' : '579+931' : '1114' : '1114'\n", | |
"'970+43' : ' 34+079' : '1013' : '1013'\n", | |
"'817+846' : '648+718' : '1663' : '1663'\n", | |
"'621+5' : ' 5+126' : '626 ' : '626'\n", | |
"'46+26' : ' 62+64' : '72 ' : '72'\n", | |
"'67+795' : ' 597+76' : '862 ' : '862'\n", | |
"'881+244' : '442+188' : '1125' : '1125'\n", | |
"'91+919' : ' 919+19' : '1010' : '1010'\n", | |
"'91+520' : ' 025+19' : '611 ' : '611'\n", | |
"'0+812' : ' 218+0' : '812 ' : '812'\n", | |
"'52+61' : ' 16+25' : '113 ' : '113'\n", | |
"'31+919' : ' 919+13' : '950 ' : '950'\n", | |
"'522+40' : ' 04+225' : '562 ' : '562'\n", | |
"'17+553' : ' 355+71' : '570 ' : '570'\n", | |
"'23+936' : ' 639+32' : '959 ' : '959'\n", | |
"'155+81' : ' 18+551' : '236 ' : '236'\n", | |
"'759+446' : '644+957' : '1205' : '1205'\n", | |
"'96+668' : ' 866+69' : '764 ' : '764'\n", | |
"'145+681' : '186+541' : '826 ' : '826'\n", | |
"'1+990' : ' 099+1' : '991 ' : '991'\n", | |
"'5+339' : ' 933+5' : '344 ' : '344'\n", | |
"'70+594' : ' 495+07' : '664 ' : '664'\n", | |
"'763+62' : ' 26+367' : '825 ' : '825'\n", | |
"'54+316' : ' 613+45' : '370 ' : '370'\n", | |
"'62+446' : ' 644+26' : '508 ' : '508'\n", | |
"'601+86' : ' 68+106' : '687 ' : '687'\n", | |
"'40+930' : ' 039+04' : '970 ' : '970'\n", | |
"'104+873' : '378+401' : '977 ' : '977'\n", | |
"'885+250' : '052+588' : '1135' : '1135'\n", | |
"'24+514' : ' 415+42' : '538 ' : '538'\n", | |
"'730+63' : ' 36+037' : '793 ' : '793'\n", | |
"'427+958' : '859+724' : '1385' : '1385'\n", | |
"Total addition questions: 50000\n" | |
] | |
} | |
], | |
"source": [ | |
"# Step - 2 : Generating the Training Data - 50000 Question of len(12), 50000 Answers of len (7)\n", | |
"questions = []\n", | |
"expected = []\n", | |
"seen = set()\n", | |
"print('Generating data...')\n", | |
"\n", | |
"qCount = 0\n", | |
"print(\"{} : {} : {} : {}\".format(\"question\", \"reversed\", \"training\", \"actual\"))\n", | |
"while len(questions) < TRAINING_SIZE:\n", | |
" f = lambda: int(''.join(np.random.choice(list('0123456789'))\n", | |
" for i in range(np.random.randint(1, DIGITS + 1))))\n", | |
" a, b = f(), f()\n", | |
" # Skip any addition questions we've already seen\n", | |
" # Also skip any such that x+Y == Y+x (hence the sorting).\n", | |
" key = tuple(sorted((a, b)))\n", | |
" if key in seen:\n", | |
" continue\n", | |
" seen.add(key)\n", | |
" \n", | |
" qCount = qCount + 1\n", | |
" # Pad the data with spaces such that it is always MAXLEN.\n", | |
" q = '{}+{}'.format(a, b)\n", | |
" query = q + ' ' * (MAXLEN - len(q))\n", | |
" ans = str(a + b)\n", | |
" # Answers can be of maximum size DIGITS + 1.\n", | |
" ans += ' ' * (DIGITS + 1 - len(ans))\n", | |
" if REVERSE:\n", | |
" # Reverse the query, e.g., '12+345 ' becomes ' 543+21'. (Note the\n", | |
" # space used for padding.)\n", | |
" query = query[::-1]\n", | |
" \n", | |
" questions.append(query)\n", | |
" expected.append(ans)\n", | |
" \n", | |
" if qCount % 1000 == 0 :\n", | |
" print(\"'{}' : '{}' : '{}' : '{}'\".format(q, query, ans, a+b))\n", | |
"\n", | |
"print('Total addition questions:', len(questions))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 145, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 864x504 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Encoded Question: ' 7+43'\n", | |
"(7, 12)\n", | |
"[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | |
" [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | |
" [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | |
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", | |
" [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | |
" [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", | |
" [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]'\n", | |
"Encoded Answer: '41 '\n", | |
"(4, 12)\n", | |
"[[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", | |
" [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | |
" [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | |
" [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]'\n" | |
] | |
} | |
], | |
"source": [ | |
"# Understanding the One Hot Encoding used\n", | |
"plt.scatter([str(\"'\") + str(c) + str(\"'\") for c in list(ctable.char_indices.keys())], list(ctable.char_indices.values()))\n", | |
"plt.title(\"One Hot Encoding by Character\")\n", | |
"plt.ylabel(\"Assigned Index\")\n", | |
"plt.xlabel(\"Encoded Character\")\n", | |
"plt.show()\n", | |
"\n", | |
"t = ctable.encode(questions[0], MAXLEN)\n", | |
"print(\"Encoded Question: '{}'\\n{}\\n{}'\".format(questions[0], t.shape, t))\n", | |
"\n", | |
"t = ctable.encode(expected[0], DIGITS + 1)\n", | |
"print(\"Encoded Answer: '{}'\\n{}\\n{}'\".format(expected[0], t.shape, t))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 146, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Vectorization...\n", | |
"(50000, 7, 12) (50000, 4, 12)\n", | |
"Sample Encoded Question: ' 7+43'\n", | |
"(7, 12)\n", | |
"[[ True False False False False False False False False False False False]\n", | |
" [ True False False False False False False False False False False False]\n", | |
" [ True False False False False False False False False False False False]\n", | |
" [False False False False False False False False False True False False]\n", | |
" [False True False False False False False False False False False False]\n", | |
" [False False False False False False True False False False False False]\n", | |
" [False False False False False True False False False False False False]]'\n", | |
"Sample Encoded Answer: '41 '\n", | |
"(7, 12)\n", | |
"[[ True False False False False False False False False False False False]\n", | |
" [ True False False False False False False False False False False False]\n", | |
" [ True False False False False False False False False False False False]\n", | |
" [False False False False False False False False False True False False]\n", | |
" [False True False False False False False False False False False False]\n", | |
" [False False False False False False True False False False False False]\n", | |
" [False False False False False True False False False False False False]]'\n" | |
] | |
} | |
], | |
"source": [ | |
"# Step - 3 : Vectorization of the questions through One Hot Encoding\n", | |
"print('Vectorization...')\n", | |
"x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.bool)\n", | |
"y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool)\n", | |
"\n", | |
"print(x.shape, y.shape)\n", | |
"\n", | |
"for i, sentence in enumerate(questions):\n", | |
" x[i] = ctable.encode(sentence, MAXLEN)\n", | |
"for i, sentence in enumerate(expected):\n", | |
" y[i] = ctable.encode(sentence, DIGITS + 1)\n", | |
" \n", | |
"print(\"Sample Encoded Question: '{}'\\n{}\\n{}'\".format(questions[0], x[0].shape, x[0]))\n", | |
"print(\"Sample Encoded Answer: '{}'\\n{}\\n{}'\".format(expected[0], x[0].shape, x[0]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 147, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((50000, 4, 12),\n", | |
" array([24763, 13191, 5224, ..., 1028, 14309, 40267]),\n", | |
" 50000,\n", | |
" 5000)" | |
] | |
}, | |
"execution_count": 147, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"testin = np.arange(len(y))\n", | |
"np.random.shuffle(testin)\n", | |
"y.shape, testin, len(x), len(x) // 10" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 148, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Training Data: (45000, 7, 12) (45000, 4, 12)\n", | |
"Validation Data: (5000, 7, 12) (5000, 4, 12)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Step - 4 : Shuffle (x, y) in unison as the later parts of x will almost all be larger\n", | |
"# digits.\n", | |
"indices = np.arange(len(y))\n", | |
"np.random.shuffle(indices)\n", | |
"x = x[indices]\n", | |
"y = y[indices]\n", | |
"\n", | |
"# Explicitly set apart 10% for validation data that we never train over.\n", | |
"# '//'' is the 'divide-and-floor' operator\n", | |
"\n", | |
"split_at = len(x) - len(x) // 10\n", | |
"(x_train, x_test) = x[:split_at], x[split_at:]\n", | |
"(y_train, y_test) = y[:split_at], y[split_at:]\n", | |
"print('Training Data:', x_train.shape, y_train.shape)\n", | |
"print('Validation Data:', x_test.shape, y_test.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 149, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Build model...\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"lstm_5 (LSTM) (None, 128) 72192 \n", | |
"_________________________________________________________________\n", | |
"repeat_vector_3 (RepeatVecto (None, 4, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"lstm_6 (LSTM) (None, 4, 128) 131584 \n", | |
"_________________________________________________________________\n", | |
"time_distributed_3 (TimeDist (None, 4, 12) 1548 \n", | |
"=================================================================\n", | |
"Total params: 205,324\n", | |
"Trainable params: 205,324\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"# Step - 5 : Building the RNN\n", | |
"\n", | |
"# Try replacing GRU, or SimpleRNN.\n", | |
"RNN = layers.LSTM\n", | |
"HIDDEN_SIZE = 128\n", | |
"BATCH_SIZE = 128\n", | |
"LAYERS = 1\n", | |
"\n", | |
"print('Build model...')\n", | |
"model = Sequential()\n", | |
"# \"Encode\" the input sequence using an RNN, producing an output of HIDDEN_SIZE.\n", | |
"# Note: In a situation where your input sequences have a variable length,\n", | |
"# use input_shape=(None, num_feature).\n", | |
"model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, len(chars))))\n", | |
"# As the decoder RNN's input, repeatedly provide with the last output of\n", | |
"# RNN for each time step. Repeat 'DIGITS + 1' times as that's the maximum\n", | |
"# length of output, e.g., when DIGITS=3, max output is 999+999=1998.\n", | |
"model.add(layers.RepeatVector(DIGITS + 1))\n", | |
"# The decoder RNN could be multiple layers stacked or a single layer.\n", | |
"for _ in range(LAYERS):\n", | |
" # By setting return_sequences to True, return not only the last output but\n", | |
" # all the outputs so far in the form of (num_samples, timesteps,\n", | |
" # output_dim). This is necessary as TimeDistributed in the below expects\n", | |
" # the first dimension to be the timesteps.\n", | |
" model.add(RNN(HIDDEN_SIZE, return_sequences=True))\n", | |
"\n", | |
"# Apply a dense layer to the every temporal slice of an input. For each of step\n", | |
"# of the output sequence, decide which character should be chosen.\n", | |
"model.add(layers.TimeDistributed(layers.Dense(len(chars), activation='softmax')))\n", | |
"model.compile(loss='categorical_crossentropy',\n", | |
" optimizer='adam',\n", | |
" metrics=['accuracy'])\n", | |
"model.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 150, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 1\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 8s 183us/step - loss: 1.8803 - acc: 0.3229 - val_loss: 1.7714 - val_acc: 0.3467\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 100 -1043\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 100 -319\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 221 117\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 60 7\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 401 -93\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 62 -662\n", | |
"Q 649+0 T 649 \u001b[91m☒\u001b[0m 100 -549\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 101 -1244\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 100 -165\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 621 496\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 2\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 150us/step - loss: 1.7083 - acc: 0.3673 - val_loss: 1.6363 - val_acc: 0.3910\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1229 86\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 403 -16\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 34 -70\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 11 -42\n", | |
"Q 73+421 T 494 \u001b[92m☑\u001b[0m 494 0\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 44 -680\n", | |
"Q 649+0 T 649 \u001b[91m☒\u001b[0m 904 255\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1229 -116\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 409 144\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 66 -59\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 3\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 153us/step - loss: 1.5615 - acc: 0.4166 - val_loss: 1.4812 - val_acc: 0.4440\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1100 -43\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 330 -89\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 12 -92\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 10 -43\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 488 -6\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 380 -344\n", | |
"Q 649+0 T 649 \u001b[91m☒\u001b[0m 650 1\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1330 -15\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 231 -34\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 126 1\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 4\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 6s 140us/step - loss: 1.4003 - acc: 0.4771 - val_loss: 1.3243 - val_acc: 0.5094\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1177 34\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 301 -118\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 110 6\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 10 -43\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 484 -10\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 731 7\n", | |
"Q 649+0 T 649 \u001b[91m☒\u001b[0m 661 12\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1329 -16\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 211 -54\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 121 -4\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 5\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 6s 138us/step - loss: 1.2541 - acc: 0.5349 - val_loss: 1.1972 - val_acc: 0.5516\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1111 -32\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 408 -11\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 116 12\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 61 8\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 485 -9\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 731 7\n", | |
"Q 649+0 T 649 \u001b[91m☒\u001b[0m 651 2\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1311 -34\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 285 20\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 121 -4\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 6\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 147us/step - loss: 1.1212 - acc: 0.5880 - val_loss: 1.0827 - val_acc: 0.5884\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1074 -69\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 400 -19\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 102 -2\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 68 15\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 480 -14\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 711 -13\n", | |
"Q 649+0 T 649 \u001b[91m☒\u001b[0m 646 -3\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1300 -45\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 268 3\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 119 -6\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 7\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 6s 138us/step - loss: 1.0100 - acc: 0.6336 - val_loss: 0.9639 - val_acc: 0.6505\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1177 34\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 410 -9\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 11 -93\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 58 5\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 495 1\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 728 4\n", | |
"Q 649+0 T 649 \u001b[91m☒\u001b[0m 659 10\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1375 30\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 253 -12\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 121 -4\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 8\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 6s 144us/step - loss: 0.9125 - acc: 0.6739 - val_loss: 0.8792 - val_acc: 0.6823\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1141 -2\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 413 -6\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 113 9\n", | |
"Q 16+37 T 53 \u001b[92m☑\u001b[0m 53 0\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 495 1\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 723 -1\n", | |
"Q 649+0 T 649 \u001b[91m☒\u001b[0m 659 10\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1349 4\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 263 -2\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 121 -4\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 9\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 147us/step - loss: 0.8281 - acc: 0.7077 - val_loss: 0.7971 - val_acc: 0.7215\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1147 4\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 418 -1\n", | |
"Q 72+32 T 104 \u001b[92m☑\u001b[0m 104 0\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 58 5\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 495 1\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 727 3\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1344 -1\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 262 -3\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 124 -1\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 10\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 6s 141us/step - loss: 0.7602 - acc: 0.7322 - val_loss: 0.7492 - val_acc: 0.7301\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1147 4\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 418 -1\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 10 -94\n", | |
"Q 16+37 T 53 \u001b[92m☑\u001b[0m 53 0\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 495 1\n", | |
"Q 3+721 T 724 \u001b[92m☑\u001b[0m 724 0\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1445 100\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 262 -3\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 128 3\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 11\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 155us/step - loss: 0.6803 - acc: 0.7581 - val_loss: 0.6253 - val_acc: 0.7754\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1144 1\n", | |
"Q 330+89 T 419 \u001b[92m☑\u001b[0m 419 0\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 105 1\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 52 -1\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 493 -1\n", | |
"Q 3+721 T 724 \u001b[92m☑\u001b[0m 724 0\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1344 -1\n", | |
"Q 73+192 T 265 \u001b[91m☒\u001b[0m 263 -2\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 126 1\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 12\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 165us/step - loss: 0.5427 - acc: 0.8073 - val_loss: 0.4794 - val_acc: 0.8296\n", | |
"Q 738+405 T 1143 \u001b[91m☒\u001b[0m 1142 -1\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 429 10\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 105 1\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 43 -10\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 493 -1\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 723 -1\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[92m☑\u001b[0m 1345 0\n", | |
"Q 73+192 T 265 \u001b[92m☑\u001b[0m 265 0\n", | |
"Q 62+63 T 125 \u001b[91m☒\u001b[0m 126 1\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 13\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 147us/step - loss: 0.3968 - acc: 0.8710 - val_loss: 0.3384 - val_acc: 0.9010\n", | |
"Q 738+405 T 1143 \u001b[92m☑\u001b[0m 1143 0\n", | |
"Q 330+89 T 419 \u001b[92m☑\u001b[0m 419 0\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 105 1\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 54 1\n", | |
"Q 73+421 T 494 \u001b[91m☒\u001b[0m 493 -1\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 723 -1\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[91m☒\u001b[0m 1445 100\n", | |
"Q 73+192 T 265 \u001b[92m☑\u001b[0m 265 0\n", | |
"Q 62+63 T 125 \u001b[92m☑\u001b[0m 125 0\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 14\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 150us/step - loss: 0.2833 - acc: 0.9250 - val_loss: 0.2467 - val_acc: 0.9398\n", | |
"Q 738+405 T 1143 \u001b[92m☑\u001b[0m 1143 0\n", | |
"Q 330+89 T 419 \u001b[92m☑\u001b[0m 419 0\n", | |
"Q 72+32 T 104 \u001b[92m☑\u001b[0m 104 0\n", | |
"Q 16+37 T 53 \u001b[91m☒\u001b[0m 55 2\n", | |
"Q 73+421 T 494 \u001b[92m☑\u001b[0m 494 0\n", | |
"Q 3+721 T 724 \u001b[91m☒\u001b[0m 723 -1\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[92m☑\u001b[0m 1345 0\n", | |
"Q 73+192 T 265 \u001b[92m☑\u001b[0m 265 0\n", | |
"Q 62+63 T 125 \u001b[92m☑\u001b[0m 125 0\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 15\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 165us/step - loss: 0.2022 - acc: 0.9557 - val_loss: 0.1814 - val_acc: 0.9583\n", | |
"Q 738+405 T 1143 \u001b[92m☑\u001b[0m 1143 0\n", | |
"Q 330+89 T 419 \u001b[92m☑\u001b[0m 419 0\n", | |
"Q 72+32 T 104 \u001b[91m☒\u001b[0m 105 1\n", | |
"Q 16+37 T 53 \u001b[92m☑\u001b[0m 53 0\n", | |
"Q 73+421 T 494 \u001b[92m☑\u001b[0m 494 0\n", | |
"Q 3+721 T 724 \u001b[92m☑\u001b[0m 724 0\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[92m☑\u001b[0m 1345 0\n", | |
"Q 73+192 T 265 \u001b[92m☑\u001b[0m 265 0\n", | |
"Q 62+63 T 125 \u001b[92m☑\u001b[0m 125 0\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 16\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 153us/step - loss: 0.1444 - acc: 0.9724 - val_loss: 0.1440 - val_acc: 0.9670\n", | |
"Q 738+405 T 1143 \u001b[92m☑\u001b[0m 1143 0\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 429 10\n", | |
"Q 72+32 T 104 \u001b[92m☑\u001b[0m 104 0\n", | |
"Q 16+37 T 53 \u001b[92m☑\u001b[0m 53 0\n", | |
"Q 73+421 T 494 \u001b[92m☑\u001b[0m 494 0\n", | |
"Q 3+721 T 724 \u001b[92m☑\u001b[0m 724 0\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[92m☑\u001b[0m 1345 0\n", | |
"Q 73+192 T 265 \u001b[92m☑\u001b[0m 265 0\n", | |
"Q 62+63 T 125 \u001b[92m☑\u001b[0m 125 0\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 17\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 154us/step - loss: 0.1099 - acc: 0.9802 - val_loss: 0.1179 - val_acc: 0.9735\n", | |
"Q 738+405 T 1143 \u001b[92m☑\u001b[0m 1143 0\n", | |
"Q 330+89 T 419 \u001b[91m☒\u001b[0m 429 10\n", | |
"Q 72+32 T 104 \u001b[92m☑\u001b[0m 104 0\n", | |
"Q 16+37 T 53 \u001b[92m☑\u001b[0m 53 0\n", | |
"Q 73+421 T 494 \u001b[92m☑\u001b[0m 494 0\n", | |
"Q 3+721 T 724 \u001b[92m☑\u001b[0m 724 0\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[92m☑\u001b[0m 1345 0\n", | |
"Q 73+192 T 265 \u001b[92m☑\u001b[0m 265 0\n", | |
"Q 62+63 T 125 \u001b[92m☑\u001b[0m 125 0\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 18\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 159us/step - loss: 0.0890 - acc: 0.9839 - val_loss: 0.0854 - val_acc: 0.9831\n", | |
"Q 738+405 T 1143 \u001b[92m☑\u001b[0m 1143 0\n", | |
"Q 330+89 T 419 \u001b[92m☑\u001b[0m 419 0\n", | |
"Q 72+32 T 104 \u001b[92m☑\u001b[0m 104 0\n", | |
"Q 16+37 T 53 \u001b[92m☑\u001b[0m 53 0\n", | |
"Q 73+421 T 494 \u001b[92m☑\u001b[0m 494 0\n", | |
"Q 3+721 T 724 \u001b[92m☑\u001b[0m 724 0\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[92m☑\u001b[0m 1345 0\n", | |
"Q 73+192 T 265 \u001b[92m☑\u001b[0m 265 0\n", | |
"Q 62+63 T 125 \u001b[92m☑\u001b[0m 125 0\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 19\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 151us/step - loss: 0.0668 - acc: 0.9899 - val_loss: 0.0694 - val_acc: 0.9871\n", | |
"Q 738+405 T 1143 \u001b[92m☑\u001b[0m 1143 0\n", | |
"Q 330+89 T 419 \u001b[92m☑\u001b[0m 419 0\n", | |
"Q 72+32 T 104 \u001b[92m☑\u001b[0m 104 0\n", | |
"Q 16+37 T 53 \u001b[92m☑\u001b[0m 53 0\n", | |
"Q 73+421 T 494 \u001b[92m☑\u001b[0m 494 0\n", | |
"Q 3+721 T 724 \u001b[92m☑\u001b[0m 724 0\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[92m☑\u001b[0m 1345 0\n", | |
"Q 73+192 T 265 \u001b[92m☑\u001b[0m 265 0\n", | |
"Q 62+63 T 125 \u001b[92m☑\u001b[0m 125 0\n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 20\n", | |
"Train on 45000 samples, validate on 5000 samples\n", | |
"Epoch 1/1\n", | |
"45000/45000 [==============================] - 7s 160us/step - loss: 0.0619 - acc: 0.9888 - val_loss: 0.0653 - val_acc: 0.9859\n", | |
"Q 738+405 T 1143 \u001b[92m☑\u001b[0m 1143 0\n", | |
"Q 330+89 T 419 \u001b[92m☑\u001b[0m 419 0\n", | |
"Q 72+32 T 104 \u001b[92m☑\u001b[0m 104 0\n", | |
"Q 16+37 T 53 \u001b[92m☑\u001b[0m 53 0\n", | |
"Q 73+421 T 494 \u001b[92m☑\u001b[0m 494 0\n", | |
"Q 3+721 T 724 \u001b[92m☑\u001b[0m 724 0\n", | |
"Q 649+0 T 649 \u001b[92m☑\u001b[0m 649 0\n", | |
"Q 600+745 T 1345 \u001b[92m☑\u001b[0m 1345 0\n", | |
"Q 73+192 T 265 \u001b[92m☑\u001b[0m 265 0\n", | |
"Q 62+63 T 125 \u001b[92m☑\u001b[0m 125 0\n" | |
] | |
} | |
], | |
"source": [ | |
"# Step - 6: Train the model each generation and show predictions against the validation dataset.\n", | |
"\n", | |
"# We will collect 10 random questions and check the prediction status on these after every few epoch\n", | |
"question_indices = np.random.randint(0, len(x_test), size=10)\n", | |
"\n", | |
"num_iterations = 20\n", | |
"learning_error_map = np.ndarray((10,num_iterations), np.int)\n", | |
"\n", | |
"for iteration in range(1, num_iterations+1):\n", | |
" print()\n", | |
" print('-' * 50)\n", | |
" print('Iteration', iteration)\n", | |
" model.fit(x_train, y_train,\n", | |
" batch_size=BATCH_SIZE,\n", | |
" epochs=1,\n", | |
" validation_data=(x_test, y_test),\n", | |
" callbacks=[])\n", | |
" \n", | |
" # Select 10 samples from the validation set at random so we can visualize\n", | |
" # errors.\n", | |
" questionNumber = 0\n", | |
" for ind in question_indices:\n", | |
" #ind = np.random.randint(0, len(x_val))\n", | |
" rowx, rowy = x_test[np.array([ind])], y_test[np.array([ind])]\n", | |
" preds = model.predict_classes(rowx, verbose=2)\n", | |
"\n", | |
" q = ctable.decode(rowx[0])\n", | |
" correct = ctable.decode(rowy[0])\n", | |
" guess = ctable.decode(preds[0], calc_argmax=False)\n", | |
" error = int(guess) - int(correct)\n", | |
" \n", | |
" print('Q', q[::-1] if REVERSE else q, end=' ') \n", | |
" print('T', correct, end=' ')\n", | |
" if correct == guess:\n", | |
" print(colors.ok + '☑' + colors.close, end=' ')\n", | |
" else:\n", | |
" print(colors.fail + '☒' + colors.close, end=' ')\n", | |
" print(guess, end=' ')\n", | |
" print(error)\n", | |
" learning_error_map[questionNumber][iteration-1] = error\n", | |
" questionNumber = questionNumber+1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x5f18fd0>,\n", | |
" <matplotlib.lines.Line2D at 0x5f201d0>,\n", | |
" <matplotlib.lines.Line2D at 0x5f20320>,\n", | |
" <matplotlib.lines.Line2D at 0x5f20470>,\n", | |
" <matplotlib.lines.Line2D at 0x5f205c0>,\n", | |
" <matplotlib.lines.Line2D at 0x5f20710>,\n", | |
" <matplotlib.lines.Line2D at 0x5f20860>,\n", | |
" <matplotlib.lines.Line2D at 0x5f209b0>,\n", | |
" <matplotlib.lines.Line2D at 0x5f20b00>,\n", | |
" <matplotlib.lines.Line2D at 0x5f20c50>,\n", | |
" <matplotlib.lines.Line2D at 0x5ef7b38>,\n", | |
" <matplotlib.lines.Line2D at 0x5f20eb8>,\n", | |
" <matplotlib.lines.Line2D at 0x5f26048>,\n", | |
" <matplotlib.lines.Line2D at 0x5f26198>,\n", | |
" <matplotlib.lines.Line2D at 0x5f262e8>,\n", | |
" <matplotlib.lines.Line2D at 0x5f26438>,\n", | |
" <matplotlib.lines.Line2D at 0x5f26588>,\n", | |
" <matplotlib.lines.Line2D at 0x5f266d8>,\n", | |
" <matplotlib.lines.Line2D at 0x5f26828>,\n", | |
" <matplotlib.lines.Line2D at 0x5f26978>]" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 864x504 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.rcParams['figure.figsize'] = [12, 7]\n", | |
"sel_ques = [questions[i] for i in question_indices]\n", | |
"plt.plot(sel_ques,learning_error_map)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 69, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def evaluateModel(x_data, y_data) :\n", | |
" sum_of_errors = 0\n", | |
" max_of_errors = 0\n", | |
" min_of_errors = np.inf\n", | |
"\n", | |
" max_error_indx = -1\n", | |
" min_error_indx = -1\n", | |
"\n", | |
" sum_of_correct = 0\n", | |
"\n", | |
" correct_count = 0\n", | |
"\n", | |
" for ind in range(0, len(x_data)):\n", | |
" #ind = np.random.randint(0, len(x_val))\n", | |
"\n", | |
" rowx, rowy = x_data[np.array([ind])], y_data[np.array([ind])]\n", | |
" preds = model.predict_classes(rowx, verbose=2)\n", | |
" q = ctable.decode(rowx[0])\n", | |
" q = q[::-1] if REVERSE else q\n", | |
"\n", | |
" correct = ctable.decode(rowy[0]) \n", | |
" guess = ctable.decode(preds[0], calc_argmax=False)\n", | |
" error = np.absolute(int(guess) - int(correct))\n", | |
" sum_of_errors = sum_of_errors + error\n", | |
" sum_of_correct = sum_of_correct + int(correct)\n", | |
"\n", | |
" if error > max_of_errors :\n", | |
" max_of_errors = error\n", | |
" max_error_indx = ind\n", | |
"\n", | |
" if error > 0 and error < min_of_errors:\n", | |
" min_of_errors = error\n", | |
" min_error_indx = ind\n", | |
"\n", | |
" if guess == correct:\n", | |
" correct_count = correct_count + 1\n", | |
" \n", | |
" if ind % 1000 == 1:\n", | |
" print(\"[{}] : Errors: Max {}, Min {}, Sum {}, Correct_sum {}, correct: {}/{}, Accuracy: {}\".format(ind, max_of_errors, min_of_errors, sum_of_errors, sum_of_correct, correct_count, len(x_data), correct_count/len(x_data)))\n", | |
"\n", | |
" print(\"Errors: Max {}, Min {}, Sum {}, Correct_sum {}, correct: {}/{}, Accuracy: {}\".format(max_of_errors, min_of_errors, sum_of_errors, sum_of_correct, correct_count, len(x_data), correct_count/len(x_data)))\n", | |
"\n", | |
" return {\n", | |
" \"max_error\" : max_of_errors, \n", | |
" \"max_error_index\" : max_error_indx,\n", | |
" \"min_error\" : min_of_errors,\n", | |
" \"min_error_index\" : min_error_indx,\n", | |
" \"correct\" : correct_count,\n", | |
" \"wrong\" : len(x_data) - correct_count,\n", | |
" \"accuracy\" : str(correct_count/len(x_data) * 100) + \"%\"\n", | |
" }" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[1] : Errors: Max 10, Min 10, Sum 10, Correct_sum 1251, correct: 1/45000, Accuracy: 2.2222222222222223e-05\n", | |
"[1001] : Errors: Max 100, Min 1, Sum 565, Correct_sum 660612, correct: 942/45000, Accuracy: 0.020933333333333335\n", | |
"[2001] : Errors: Max 100, Min 1, Sum 1292, Correct_sum 1327546, correct: 1894/45000, Accuracy: 0.04208888888888889\n", | |
"[3001] : Errors: Max 8860, Min 1, Sum 10982, Correct_sum 2004414, correct: 2843/45000, Accuracy: 0.06317777777777778\n", | |
"[4001] : Errors: Max 8860, Min 1, Sum 12746, Correct_sum 2677280, correct: 3794/45000, Accuracy: 0.08431111111111111\n", | |
"[5001] : Errors: Max 8860, Min 1, Sum 13775, Correct_sum 3343505, correct: 4741/45000, Accuracy: 0.10535555555555555\n", | |
"[6001] : Errors: Max 8860, Min 1, Sum 14963, Correct_sum 4018502, correct: 5682/45000, Accuracy: 0.12626666666666667\n", | |
"[7001] : Errors: Max 8860, Min 1, Sum 16149, Correct_sum 4708606, correct: 6630/45000, Accuracy: 0.14733333333333334\n", | |
"[8001] : Errors: Max 8860, Min 1, Sum 16730, Correct_sum 5380712, correct: 7585/45000, Accuracy: 0.16855555555555554\n", | |
"[9001] : Errors: Max 8860, Min 1, Sum 25849, Correct_sum 6062012, correct: 8533/45000, Accuracy: 0.18962222222222222\n", | |
"[10001] : Errors: Max 8860, Min 1, Sum 34510, Correct_sum 6720593, correct: 9490/45000, Accuracy: 0.21088888888888888\n", | |
"[11001] : Errors: Max 8860, Min 1, Sum 42938, Correct_sum 7387953, correct: 10436/45000, Accuracy: 0.23191111111111112\n", | |
"[12001] : Errors: Max 8860, Min 1, Sum 52080, Correct_sum 8062013, correct: 11386/45000, Accuracy: 0.2530222222222222\n", | |
"[13001] : Errors: Max 8860, Min 1, Sum 53801, Correct_sum 8727100, correct: 12334/45000, Accuracy: 0.2740888888888889\n", | |
"[14001] : Errors: Max 8860, Min 1, Sum 62823, Correct_sum 9393554, correct: 13278/45000, Accuracy: 0.29506666666666664\n", | |
"[15001] : Errors: Max 8860, Min 1, Sum 63209, Correct_sum 10046954, correct: 14231/45000, Accuracy: 0.31624444444444444\n", | |
"[16001] : Errors: Max 8860, Min 1, Sum 65346, Correct_sum 10729773, correct: 15182/45000, Accuracy: 0.3373777777777778\n", | |
"[17001] : Errors: Max 8860, Min 1, Sum 66728, Correct_sum 11386352, correct: 16136/45000, Accuracy: 0.3585777777777778\n", | |
"[18001] : Errors: Max 8860, Min 1, Sum 67716, Correct_sum 12053681, correct: 17072/45000, Accuracy: 0.37937777777777776\n", | |
"[19001] : Errors: Max 8860, Min 1, Sum 68147, Correct_sum 12733589, correct: 18033/45000, Accuracy: 0.40073333333333333\n", | |
"[20001] : Errors: Max 8860, Min 1, Sum 68831, Correct_sum 13419522, correct: 18981/45000, Accuracy: 0.4218\n", | |
"[21001] : Errors: Max 8860, Min 1, Sum 85288, Correct_sum 14102391, correct: 19934/45000, Accuracy: 0.4429777777777778\n", | |
"[22001] : Errors: Max 8860, Min 1, Sum 101838, Correct_sum 14782431, correct: 20884/45000, Accuracy: 0.4640888888888889\n", | |
"[23001] : Errors: Max 8860, Min 1, Sum 102150, Correct_sum 15434826, correct: 21836/45000, Accuracy: 0.4852444444444444\n", | |
"[24001] : Errors: Max 8860, Min 1, Sum 102771, Correct_sum 16086003, correct: 22797/45000, Accuracy: 0.5066\n", | |
"[25001] : Errors: Max 8860, Min 1, Sum 111578, Correct_sum 16765754, correct: 23744/45000, Accuracy: 0.5276444444444445\n", | |
"[26001] : Errors: Max 8860, Min 1, Sum 120587, Correct_sum 17460192, correct: 24674/45000, Accuracy: 0.5483111111111111\n", | |
"[27001] : Errors: Max 8860, Min 1, Sum 121203, Correct_sum 18116304, correct: 25620/45000, Accuracy: 0.5693333333333334\n", | |
"[28001] : Errors: Max 8860, Min 1, Sum 121998, Correct_sum 18778038, correct: 26573/45000, Accuracy: 0.5905111111111111\n", | |
"[29001] : Errors: Max 8860, Min 1, Sum 122528, Correct_sum 19434244, correct: 27535/45000, Accuracy: 0.6118888888888889\n", | |
"[30001] : Errors: Max 8860, Min 1, Sum 130983, Correct_sum 20116135, correct: 28481/45000, Accuracy: 0.6329111111111111\n", | |
"[31001] : Errors: Max 8860, Min 1, Sum 156412, Correct_sum 20793751, correct: 29435/45000, Accuracy: 0.6541111111111111\n", | |
"[32001] : Errors: Max 8860, Min 1, Sum 165446, Correct_sum 21470656, correct: 30380/45000, Accuracy: 0.6751111111111111\n", | |
"[33001] : Errors: Max 8860, Min 1, Sum 174136, Correct_sum 22129274, correct: 31334/45000, Accuracy: 0.6963111111111111\n", | |
"[34001] : Errors: Max 8860, Min 1, Sum 175203, Correct_sum 22792720, correct: 32279/45000, Accuracy: 0.7173111111111111\n", | |
"[35001] : Errors: Max 8860, Min 1, Sum 183506, Correct_sum 23470643, correct: 33242/45000, Accuracy: 0.7387111111111111\n", | |
"[36001] : Errors: Max 8860, Min 1, Sum 192091, Correct_sum 24155242, correct: 34194/45000, Accuracy: 0.7598666666666667\n", | |
"[37001] : Errors: Max 8860, Min 1, Sum 200558, Correct_sum 24805652, correct: 35140/45000, Accuracy: 0.7808888888888889\n", | |
"[38001] : Errors: Max 8860, Min 1, Sum 209300, Correct_sum 25473790, correct: 36087/45000, Accuracy: 0.8019333333333334\n", | |
"[39001] : Errors: Max 8860, Min 1, Sum 209944, Correct_sum 26158820, correct: 37028/45000, Accuracy: 0.8228444444444445\n", | |
"[40001] : Errors: Max 8860, Min 1, Sum 210927, Correct_sum 26815289, correct: 37986/45000, Accuracy: 0.8441333333333333\n", | |
"[41001] : Errors: Max 8860, Min 1, Sum 219803, Correct_sum 27489924, correct: 38927/45000, Accuracy: 0.8650444444444444\n", | |
"[42001] : Errors: Max 8860, Min 1, Sum 220199, Correct_sum 28163650, correct: 39884/45000, Accuracy: 0.8863111111111112\n", | |
"[43001] : Errors: Max 8860, Min 1, Sum 229090, Correct_sum 28835054, correct: 40833/45000, Accuracy: 0.9074\n", | |
"[44001] : Errors: Max 8860, Min 1, Sum 245933, Correct_sum 29515309, correct: 41789/45000, Accuracy: 0.9286444444444445\n", | |
"Errors: Max 8860, Min 1, Sum 254355, Correct_sum 30181321, correct: 42748/45000, Accuracy: 0.9499555555555556\n", | |
"Train data stats: {} {'max_error': 8860, 'max_error_index': 2554, 'min_error': 1, 'min_error_index': 8, 'correct': 42748, 'wrong': 2252, 'accuracy': '94.99555555555555%'}\n", | |
"-------------------------------------------------------\n", | |
"[1] : Errors: Max 0, Min inf, Sum 0, Correct_sum 1059, correct: 2/5000, Accuracy: 0.0004\n", | |
"[1001] : Errors: Max 100, Min 1, Sum 1168, Correct_sum 669355, correct: 944/5000, Accuracy: 0.1888\n", | |
"[2001] : Errors: Max 100, Min 1, Sum 2351, Correct_sum 1341388, correct: 1877/5000, Accuracy: 0.3754\n", | |
"[3001] : Errors: Max 101, Min 1, Sum 3673, Correct_sum 2037160, correct: 2801/5000, Accuracy: 0.5602\n", | |
"[4001] : Errors: Max 8050, Min 1, Sum 12613, Correct_sum 2728724, correct: 3730/5000, Accuracy: 0.746\n", | |
"Errors: Max 8050, Min 1, Sum 13596, Correct_sum 3390692, correct: 4677/5000, Accuracy: 0.9354\n", | |
"Train data stats: {} {'max_error': 8050, 'max_error_index': 3423, 'min_error': 1, 'min_error_index': 4, 'correct': 4677, 'wrong': 323, 'accuracy': '93.54%'}\n" | |
] | |
} | |
], | |
"source": [ | |
"train_eval = evaluateModel(x_train, y_train)\n", | |
"print(\"Train data stats: {}\", train_eval)\n", | |
"print(\"-\"*55)\n", | |
"test_eval = evaluateModel(x_test, y_test)\n", | |
"print(\"Train data stats: {}\", test_eval)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 154, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def printQuestionDetails(x_test, y_test, indx) :\n", | |
" rowx, rowy = x_test[np.array([indx])], y_test[np.array([indx])]\n", | |
" preds = model.predict_classes(rowx, verbose=2)\n", | |
" q = ctable.decode(rowx[0])\n", | |
" q = q[::-1] if REVERSE else q\n", | |
"\n", | |
" correct = ctable.decode(rowy[0])\n", | |
" guess = ctable.decode(preds[0], calc_argmax=False)\n", | |
" error = int(guess) - int(correct) \n", | |
" \n", | |
" print('Q', q[::-1] if REVERSE else q, end=' ') \n", | |
" print('T', correct, end=' ')\n", | |
" if correct == guess:\n", | |
" print(colors.ok + '☑' + colors.close, end=' ')\n", | |
" else:\n", | |
" print(colors.fail + '☒' + colors.close, end=' ')\n", | |
" print(guess, end=' ')\n", | |
" print(error)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 155, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Q 8+501 T 113 \u001b[92m☑\u001b[0m 113 0\n" | |
] | |
} | |
], | |
"source": [ | |
"printQuestionDetails(x_test, y_test, max_error_indx)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 156, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Q 78+43 T 121 \u001b[92m☑\u001b[0m 121 0\n" | |
] | |
} | |
], | |
"source": [ | |
"printQuestionDetails(x_test, y_test, min_error_indx)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 157, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def calculate_add(a, b) :\n", | |
" q = '{}+{}'.format(a, b)\n", | |
" q = q[::-1] if REVERSE else q\n", | |
" \n", | |
" query = q + ' ' * (MAXLEN - len(q))\n", | |
" #print(query)\n", | |
" query_enc = ctable.encode(query, MAXLEN)>0\n", | |
" #print(query_enc) \n", | |
" query_enc = query_enc.reshape((1, query_enc.shape[0], query_enc.shape[1]))\n", | |
" #print(query_enc)\n", | |
" #print(\"-\"*55)\n", | |
" #print(x_train[np.array([10])].shape)\n", | |
" \n", | |
" key = tuple(sorted((a, b)))\n", | |
" if key in seen:\n", | |
" print('Seen:', end=' ')\n", | |
" \n", | |
" preds = model.predict_classes(query_enc, verbose=2) \n", | |
"\n", | |
" correct = str(a + b)\n", | |
" guess = ctable.decode(preds[0], calc_argmax=False) \n", | |
" error = int(guess) - int(correct)\n", | |
" print('Q:(', q[::-1] if REVERSE else q, ')', end=' ') \n", | |
" print('T:(', correct, ')', end=' ')\n", | |
" \n", | |
" if int(correct) == int(guess):\n", | |
" print(colors.ok + '☑' + colors.close, end=' ')\n", | |
" else:\n", | |
" print(colors.fail + '☒' + colors.close, end=' ')\n", | |
" print('G:(', guess, ')', end=' ')\n", | |
" print('E:(', error, ')')\n", | |
" \n", | |
" return {\n", | |
" \"error\" : error > 0, \n", | |
" \"seen\" : key in seen\n", | |
" }" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 159, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Q:( 388+791 ) T:( 1179 ) \u001b[92m☑\u001b[0m G:( 1179 ) E:( 0 )\n", | |
"Q:( 378+83 ) T:( 461 ) \u001b[91m☒\u001b[0m G:( 1593 ) E:( 1132 )\n", | |
"Q:( 677+562 ) T:( 1239 ) \u001b[92m☑\u001b[0m G:( 1239 ) E:( 0 )\n", | |
"Q:( 158+562 ) T:( 720 ) \u001b[92m☑\u001b[0m G:( 720 ) E:( 0 )\n", | |
"Q:( 901+829 ) T:( 1730 ) \u001b[91m☒\u001b[0m G:( 1720 ) E:( -10 )\n", | |
"Seen: Q:( 357+249 ) T:( 606 ) \u001b[92m☑\u001b[0m G:( 606 ) E:( 0 )\n", | |
"Q:( 852+967 ) T:( 1819 ) \u001b[92m☑\u001b[0m G:( 1819 ) E:( 0 )\n", | |
"Q:( 257+511 ) T:( 768 ) \u001b[92m☑\u001b[0m G:( 768 ) E:( 0 )\n", | |
"Q:( 328+674 ) T:( 1002 ) \u001b[92m☑\u001b[0m G:( 1002 ) E:( 0 )\n", | |
"Q:( 951+331 ) T:( 1282 ) \u001b[92m☑\u001b[0m G:( 1282 ) E:( 0 )\n", | |
"Q:( 747+56 ) T:( 803 ) \u001b[91m☒\u001b[0m G:( 1483 ) E:( 680 )\n", | |
"Q:( 959+794 ) T:( 1753 ) \u001b[91m☒\u001b[0m G:( 1743 ) E:( -10 )\n", | |
"Q:( 477+934 ) T:( 1411 ) \u001b[92m☑\u001b[0m G:( 1411 ) E:( 0 )\n", | |
"Q:( 769+396 ) T:( 1165 ) \u001b[91m☒\u001b[0m G:( 1155 ) E:( -10 )\n", | |
"Q:( 524+186 ) T:( 710 ) \u001b[92m☑\u001b[0m G:( 710 ) E:( 0 )\n", | |
"Q:( 338+210 ) T:( 548 ) \u001b[92m☑\u001b[0m G:( 548 ) E:( 0 )\n", | |
"Q:( 894+306 ) T:( 1200 ) \u001b[92m☑\u001b[0m G:( 1200 ) E:( 0 )\n", | |
"Q:( 137+32 ) T:( 169 ) \u001b[91m☒\u001b[0m G:( 1130 ) E:( 961 )\n", | |
"Q:( 912+264 ) T:( 1176 ) \u001b[92m☑\u001b[0m G:( 1176 ) E:( 0 )\n", | |
"Q:( 621+196 ) T:( 817 ) \u001b[92m☑\u001b[0m G:( 817 ) E:( 0 )\n", | |
"Q:( 464+898 ) T:( 1362 ) \u001b[92m☑\u001b[0m G:( 1362 ) E:( 0 )\n", | |
"Q:( 653+948 ) T:( 1601 ) \u001b[92m☑\u001b[0m G:( 1601 ) E:( 0 )\n", | |
"Q:( 959+289 ) T:( 1248 ) \u001b[92m☑\u001b[0m G:( 1248 ) E:( 0 )\n", | |
"Q:( 825+805 ) T:( 1630 ) \u001b[92m☑\u001b[0m G:( 1630 ) E:( 0 )\n", | |
"Q:( 569+718 ) T:( 1287 ) \u001b[92m☑\u001b[0m G:( 1287 ) E:( 0 )\n", | |
"Q:( 688+476 ) T:( 1164 ) \u001b[92m☑\u001b[0m G:( 1164 ) E:( 0 )\n", | |
"Q:( 62+950 ) T:( 1012 ) \u001b[91m☒\u001b[0m G:( 1602 ) E:( 590 )\n", | |
"Q:( 640+641 ) T:( 1281 ) \u001b[92m☑\u001b[0m G:( 1281 ) E:( 0 )\n", | |
"Seen: Q:( 694+65 ) T:( 759 ) \u001b[91m☒\u001b[0m G:( 1420 ) E:( 661 )\n", | |
"Q:( 419+281 ) T:( 700 ) \u001b[91m☒\u001b[0m G:( 600 ) E:( -100 )\n", | |
"Q:( 194+26 ) T:( 220 ) \u001b[91m☒\u001b[0m G:( 1261 ) E:( 1041 )\n", | |
"Q:( 551+526 ) T:( 1077 ) \u001b[92m☑\u001b[0m G:( 1077 ) E:( 0 )\n", | |
"Q:( 398+601 ) T:( 999 ) \u001b[91m☒\u001b[0m G:( 109 ) E:( -890 )\n", | |
"Q:( 226+285 ) T:( 511 ) \u001b[92m☑\u001b[0m G:( 511 ) E:( 0 )\n", | |
"Q:( 307+488 ) T:( 795 ) \u001b[92m☑\u001b[0m G:( 795 ) E:( 0 )\n", | |
"Q:( 234+924 ) T:( 1158 ) \u001b[92m☑\u001b[0m G:( 1158 ) E:( 0 )\n", | |
"Q:( 192+463 ) T:( 655 ) \u001b[92m☑\u001b[0m G:( 655 ) E:( 0 )\n", | |
"Q:( 875+544 ) T:( 1419 ) \u001b[92m☑\u001b[0m G:( 1419 ) E:( 0 )\n", | |
"Q:( 185+989 ) T:( 1174 ) \u001b[91m☒\u001b[0m G:( 1164 ) E:( -10 )\n", | |
"Seen: Q:( 624+40 ) T:( 664 ) \u001b[91m☒\u001b[0m G:( 1255 ) E:( 591 )\n", | |
"Q:( 184+988 ) T:( 1172 ) \u001b[92m☑\u001b[0m G:( 1172 ) E:( 0 )\n", | |
"Q:( 161+932 ) T:( 1093 ) \u001b[92m☑\u001b[0m G:( 1093 ) E:( 0 )\n", | |
"Q:( 589+563 ) T:( 1152 ) \u001b[92m☑\u001b[0m G:( 1152 ) E:( 0 )\n", | |
"Q:( 406+990 ) T:( 1396 ) \u001b[92m☑\u001b[0m G:( 1396 ) E:( 0 )\n", | |
"Q:( 425+153 ) T:( 578 ) \u001b[92m☑\u001b[0m G:( 578 ) E:( 0 )\n", | |
"Q:( 183+272 ) T:( 455 ) \u001b[92m☑\u001b[0m G:( 455 ) E:( 0 )\n", | |
"Q:( 488+608 ) T:( 1096 ) \u001b[92m☑\u001b[0m G:( 1096 ) E:( 0 )\n", | |
"Q:( 485+300 ) T:( 785 ) \u001b[92m☑\u001b[0m G:( 785 ) E:( 0 )\n", | |
"Q:( 480+424 ) T:( 904 ) \u001b[91m☒\u001b[0m G:( 804 ) E:( -100 )\n", | |
"Q:( 848+19 ) T:( 867 ) \u001b[91m☒\u001b[0m G:( 1278 ) E:( 411 )\n", | |
"Q:( 69+616 ) T:( 685 ) \u001b[91m☒\u001b[0m G:( 1275 ) E:( 590 )\n", | |
"Q:( 962+66 ) T:( 1028 ) \u001b[91m☒\u001b[0m G:( 1637 ) E:( 609 )\n", | |
"Q:( 859+934 ) T:( 1793 ) \u001b[92m☑\u001b[0m G:( 1793 ) E:( 0 )\n", | |
"Q:( 577+390 ) T:( 967 ) \u001b[92m☑\u001b[0m G:( 967 ) E:( 0 )\n", | |
"Q:( 514+486 ) T:( 1000 ) \u001b[91m☒\u001b[0m G:( 990 ) E:( -10 )\n", | |
"Q:( 341+845 ) T:( 1186 ) \u001b[92m☑\u001b[0m G:( 1186 ) E:( 0 )\n", | |
"Q:( 447+867 ) T:( 1314 ) \u001b[92m☑\u001b[0m G:( 1314 ) E:( 0 )\n", | |
"Q:( 156+246 ) T:( 402 ) \u001b[92m☑\u001b[0m G:( 402 ) E:( 0 )\n", | |
"Q:( 409+349 ) T:( 758 ) \u001b[92m☑\u001b[0m G:( 758 ) E:( 0 )\n", | |
"Q:( 453+718 ) T:( 1171 ) \u001b[92m☑\u001b[0m G:( 1171 ) E:( 0 )\n", | |
"Q:( 531+321 ) T:( 852 ) \u001b[92m☑\u001b[0m G:( 852 ) E:( 0 )\n", | |
"Seen: Q:( 186+12 ) T:( 198 ) \u001b[91m☒\u001b[0m G:( 1130 ) E:( 932 )\n", | |
"Q:( 435+819 ) T:( 1254 ) \u001b[92m☑\u001b[0m G:( 1254 ) E:( 0 )\n", | |
"Q:( 662+513 ) T:( 1175 ) \u001b[92m☑\u001b[0m G:( 1175 ) E:( 0 )\n", | |
"Q:( 267+167 ) T:( 434 ) \u001b[92m☑\u001b[0m G:( 434 ) E:( 0 )\n", | |
"Q:( 502+737 ) T:( 1239 ) \u001b[92m☑\u001b[0m G:( 1239 ) E:( 0 )\n", | |
"Q:( 187+688 ) T:( 875 ) \u001b[92m☑\u001b[0m G:( 875 ) E:( 0 )\n", | |
"Q:( 803+871 ) T:( 1674 ) \u001b[92m☑\u001b[0m G:( 1674 ) E:( 0 )\n", | |
"Seen: Q:( 96+174 ) T:( 270 ) \u001b[91m☒\u001b[0m G:( 770 ) E:( 500 )\n", | |
"Q:( 193+576 ) T:( 769 ) \u001b[92m☑\u001b[0m G:( 769 ) E:( 0 )\n", | |
"Q:( 821+965 ) T:( 1786 ) \u001b[92m☑\u001b[0m G:( 1786 ) E:( 0 )\n", | |
"Seen: Q:( 387+6 ) T:( 393 ) \u001b[91m☒\u001b[0m G:( 1013 ) E:( 620 )\n", | |
"Q:( 575+783 ) T:( 1358 ) \u001b[92m☑\u001b[0m G:( 1358 ) E:( 0 )\n", | |
"Q:( 657+457 ) T:( 1114 ) \u001b[92m☑\u001b[0m G:( 1114 ) E:( 0 )\n", | |
"Q:( 727+250 ) T:( 977 ) \u001b[92m☑\u001b[0m G:( 977 ) E:( 0 )\n", | |
"Q:( 515+845 ) T:( 1360 ) \u001b[91m☒\u001b[0m G:( 1350 ) E:( -10 )\n", | |
"Q:( 392+104 ) T:( 496 ) \u001b[92m☑\u001b[0m G:( 496 ) E:( 0 )\n", | |
"Q:( 296+998 ) T:( 1294 ) \u001b[91m☒\u001b[0m G:( 1285 ) E:( -9 )\n", | |
"Q:( 905+113 ) T:( 1018 ) \u001b[92m☑\u001b[0m G:( 1018 ) E:( 0 )\n", | |
"Q:( 766+154 ) T:( 920 ) \u001b[92m☑\u001b[0m G:( 920 ) E:( 0 )\n", | |
"Q:( 666+534 ) T:( 1200 ) \u001b[91m☒\u001b[0m G:( 1100 ) E:( -100 )\n", | |
"Q:( 193+991 ) T:( 1184 ) \u001b[92m☑\u001b[0m G:( 1184 ) E:( 0 )\n", | |
"Q:( 899+728 ) T:( 1627 ) \u001b[92m☑\u001b[0m G:( 1627 ) E:( 0 )\n", | |
"Q:( 430+130 ) T:( 560 ) \u001b[92m☑\u001b[0m G:( 560 ) E:( 0 )\n", | |
"Q:( 996+29 ) T:( 1025 ) \u001b[91m☒\u001b[0m G:( 1434 ) E:( 409 )\n", | |
"Q:( 557+281 ) T:( 838 ) \u001b[92m☑\u001b[0m G:( 838 ) E:( 0 )\n", | |
"Q:( 934+967 ) T:( 1901 ) \u001b[92m☑\u001b[0m G:( 1901 ) E:( 0 )\n", | |
"Q:( 197+596 ) T:( 793 ) \u001b[91m☒\u001b[0m G:( 783 ) E:( -10 )\n", | |
"Seen: Q:( 998+504 ) T:( 1502 ) \u001b[92m☑\u001b[0m G:( 1502 ) E:( 0 )\n", | |
"Q:( 406+672 ) T:( 1078 ) \u001b[92m☑\u001b[0m G:( 1078 ) E:( 0 )\n", | |
"Q:( 421+397 ) T:( 818 ) \u001b[92m☑\u001b[0m G:( 818 ) E:( 0 )\n", | |
"Q:( 698+428 ) T:( 1126 ) \u001b[92m☑\u001b[0m G:( 1126 ) E:( 0 )\n", | |
"Q:( 168+443 ) T:( 611 ) \u001b[92m☑\u001b[0m G:( 611 ) E:( 0 )\n", | |
"Q:( 891+653 ) T:( 1544 ) \u001b[92m☑\u001b[0m G:( 1544 ) E:( 0 )\n", | |
"Q:( 18+325 ) T:( 343 ) \u001b[91m☒\u001b[0m G:( 843 ) E:( 500 )\n", | |
"Seen: Q:( 2+791 ) T:( 793 ) \u001b[91m☒\u001b[0m G:( 1453 ) E:( 660 )\n", | |
"Q:( 431+593 ) T:( 1024 ) \u001b[92m☑\u001b[0m G:( 1024 ) E:( 0 )\n", | |
"Q:( 582+671 ) T:( 1253 ) \u001b[92m☑\u001b[0m G:( 1253 ) E:( 0 )\n", | |
"Q:( 810+80 ) T:( 890 ) \u001b[91m☒\u001b[0m G:( 130 ) E:( -760 )\n", | |
"Q:( 412+131 ) T:( 543 ) \u001b[91m☒\u001b[0m G:( 542 ) E:( -1 )\n", | |
"Errors: 16/100. Seen: 8\n" | |
] | |
} | |
], | |
"source": [ | |
"errors = 0\n", | |
"seens = 0\n", | |
"num_trials = 100\n", | |
"for i in range(0, num_trials):\n", | |
" res = calculate_add(np.random.randint(0, 999), np.random.randint(0, 999))\n", | |
" errors = errors + res['error']\n", | |
" seens = seens + (res['seen'] != False)\n", | |
"print(\"Errors: {}/{}. Seen: {}\".format(errors, num_trials, seens))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 173, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Seen: Q:( 0+791 ) T:( 791 ) \u001b[91m☒\u001b[0m G:( 1451 ) E:( 660 )\n", | |
"Seen: Q:( 1+791 ) T:( 792 ) \u001b[91m☒\u001b[0m G:( 1452 ) E:( 660 )\n", | |
"Seen: Q:( 2+791 ) T:( 793 ) \u001b[91m☒\u001b[0m G:( 1453 ) E:( 660 )\n", | |
"Seen: Q:( 3+791 ) T:( 794 ) \u001b[91m☒\u001b[0m G:( 1454 ) E:( 660 )\n", | |
"Seen: Q:( 4+791 ) T:( 795 ) \u001b[91m☒\u001b[0m G:( 1455 ) E:( 660 )\n", | |
"Seen: Q:( 5+791 ) T:( 796 ) \u001b[91m☒\u001b[0m G:( 1456 ) E:( 660 )\n", | |
"Seen: Q:( 6+791 ) T:( 797 ) \u001b[91m☒\u001b[0m G:( 1457 ) E:( 660 )\n", | |
"Seen: Q:( 7+791 ) T:( 798 ) \u001b[91m☒\u001b[0m G:( 1458 ) E:( 660 )\n", | |
"Seen: Q:( 8+791 ) T:( 799 ) \u001b[91m☒\u001b[0m G:( 1469 ) E:( 670 )\n", | |
"Seen: Q:( 9+791 ) T:( 800 ) \u001b[91m☒\u001b[0m G:( 1360 ) E:( 560 )\n", | |
"Q:( 10+791 ) T:( 801 ) \u001b[91m☒\u001b[0m G:( 1391 ) E:( 590 )\n", | |
"Q:( 11+791 ) T:( 802 ) \u001b[91m☒\u001b[0m G:( 1392 ) E:( 590 )\n", | |
"Q:( 12+791 ) T:( 803 ) \u001b[91m☒\u001b[0m G:( 1393 ) E:( 590 )\n", | |
"Seen: Q:( 13+791 ) T:( 804 ) \u001b[91m☒\u001b[0m G:( 1394 ) E:( 590 )\n", | |
"Seen: Q:( 14+791 ) T:( 805 ) \u001b[91m☒\u001b[0m G:( 1395 ) E:( 590 )\n", | |
"Q:( 15+791 ) T:( 806 ) \u001b[91m☒\u001b[0m G:( 1396 ) E:( 590 )\n", | |
"Seen: Q:( 16+791 ) T:( 807 ) \u001b[91m☒\u001b[0m G:( 1397 ) E:( 590 )\n", | |
"Q:( 17+791 ) T:( 808 ) \u001b[91m☒\u001b[0m G:( 1398 ) E:( 590 )\n", | |
"Q:( 18+791 ) T:( 809 ) \u001b[91m☒\u001b[0m G:( 1309 ) E:( 500 )\n", | |
"Q:( 19+791 ) T:( 810 ) \u001b[91m☒\u001b[0m G:( 1300 ) E:( 490 )\n", | |
"Q:( 20+791 ) T:( 811 ) \u001b[91m☒\u001b[0m G:( 1401 ) E:( 590 )\n", | |
"Seen: Q:( 21+791 ) T:( 812 ) \u001b[91m☒\u001b[0m G:( 1402 ) E:( 590 )\n", | |
"Seen: Q:( 22+791 ) T:( 813 ) \u001b[91m☒\u001b[0m G:( 1303 ) E:( 490 )\n", | |
"Q:( 23+791 ) T:( 814 ) \u001b[91m☒\u001b[0m G:( 1304 ) E:( 490 )\n", | |
"Q:( 24+791 ) T:( 815 ) \u001b[91m☒\u001b[0m G:( 1405 ) E:( 590 )\n", | |
"Seen: Q:( 25+791 ) T:( 816 ) \u001b[91m☒\u001b[0m G:( 1406 ) E:( 590 )\n", | |
"Q:( 26+791 ) T:( 817 ) \u001b[91m☒\u001b[0m G:( 1407 ) E:( 590 )\n", | |
"Seen: Q:( 27+791 ) T:( 818 ) \u001b[91m☒\u001b[0m G:( 1408 ) E:( 590 )\n", | |
"Q:( 28+791 ) T:( 819 ) \u001b[91m☒\u001b[0m G:( 1409 ) E:( 590 )\n", | |
"Q:( 29+791 ) T:( 820 ) \u001b[91m☒\u001b[0m G:( 1410 ) E:( 590 )\n", | |
"Q:( 30+791 ) T:( 821 ) \u001b[91m☒\u001b[0m G:( 1411 ) E:( 590 )\n", | |
"Q:( 31+791 ) T:( 822 ) \u001b[91m☒\u001b[0m G:( 1412 ) E:( 590 )\n", | |
"Q:( 32+791 ) T:( 823 ) \u001b[91m☒\u001b[0m G:( 1413 ) E:( 590 )\n", | |
"Q:( 33+791 ) T:( 824 ) \u001b[91m☒\u001b[0m G:( 1414 ) E:( 590 )\n", | |
"Seen: Q:( 34+791 ) T:( 825 ) \u001b[91m☒\u001b[0m G:( 1415 ) E:( 590 )\n", | |
"Q:( 35+791 ) T:( 826 ) \u001b[91m☒\u001b[0m G:( 1416 ) E:( 590 )\n", | |
"Seen: Q:( 36+791 ) T:( 827 ) \u001b[91m☒\u001b[0m G:( 1417 ) E:( 590 )\n", | |
"Seen: Q:( 37+791 ) T:( 828 ) \u001b[91m☒\u001b[0m G:( 1418 ) E:( 590 )\n", | |
"Q:( 38+791 ) T:( 829 ) \u001b[91m☒\u001b[0m G:( 1429 ) E:( 600 )\n", | |
"Seen: Q:( 39+791 ) T:( 830 ) \u001b[91m☒\u001b[0m G:( 1420 ) E:( 590 )\n", | |
"Q:( 40+791 ) T:( 831 ) \u001b[91m☒\u001b[0m G:( 1431 ) E:( 600 )\n", | |
"Q:( 41+791 ) T:( 832 ) \u001b[91m☒\u001b[0m G:( 1432 ) E:( 600 )\n", | |
"Q:( 42+791 ) T:( 833 ) \u001b[91m☒\u001b[0m G:( 1433 ) E:( 600 )\n", | |
"Q:( 43+791 ) T:( 834 ) \u001b[91m☒\u001b[0m G:( 1334 ) E:( 500 )\n", | |
"Q:( 44+791 ) T:( 835 ) \u001b[91m☒\u001b[0m G:( 1425 ) E:( 590 )\n", | |
"Seen: Q:( 45+791 ) T:( 836 ) \u001b[91m☒\u001b[0m G:( 1426 ) E:( 590 )\n", | |
"Q:( 46+791 ) T:( 837 ) \u001b[91m☒\u001b[0m G:( 1437 ) E:( 600 )\n", | |
"Q:( 47+791 ) T:( 838 ) \u001b[91m☒\u001b[0m G:( 1438 ) E:( 600 )\n", | |
"Q:( 48+791 ) T:( 839 ) \u001b[91m☒\u001b[0m G:( 1439 ) E:( 600 )\n", | |
"Q:( 49+791 ) T:( 840 ) \u001b[91m☒\u001b[0m G:( 1430 ) E:( 590 )\n", | |
"Q:( 50+791 ) T:( 841 ) \u001b[91m☒\u001b[0m G:( 1441 ) E:( 600 )\n", | |
"Q:( 51+791 ) T:( 842 ) \u001b[91m☒\u001b[0m G:( 1442 ) E:( 600 )\n", | |
"Seen: Q:( 52+791 ) T:( 843 ) \u001b[91m☒\u001b[0m G:( 1443 ) E:( 600 )\n", | |
"Seen: Q:( 53+791 ) T:( 844 ) \u001b[91m☒\u001b[0m G:( 1444 ) E:( 600 )\n", | |
"Q:( 54+791 ) T:( 845 ) \u001b[91m☒\u001b[0m G:( 1445 ) E:( 600 )\n", | |
"Seen: Q:( 55+791 ) T:( 846 ) \u001b[91m☒\u001b[0m G:( 1436 ) E:( 590 )\n", | |
"Q:( 56+791 ) T:( 847 ) \u001b[91m☒\u001b[0m G:( 1447 ) E:( 600 )\n", | |
"Q:( 57+791 ) T:( 848 ) \u001b[91m☒\u001b[0m G:( 1448 ) E:( 600 )\n", | |
"Q:( 58+791 ) T:( 849 ) \u001b[91m☒\u001b[0m G:( 1449 ) E:( 600 )\n", | |
"Q:( 59+791 ) T:( 850 ) \u001b[91m☒\u001b[0m G:( 1450 ) E:( 600 )\n", | |
"Seen: Q:( 60+791 ) T:( 851 ) \u001b[91m☒\u001b[0m G:( 1451 ) E:( 600 )\n", | |
"Q:( 61+791 ) T:( 852 ) \u001b[91m☒\u001b[0m G:( 1452 ) E:( 600 )\n", | |
"Q:( 62+791 ) T:( 853 ) \u001b[91m☒\u001b[0m G:( 1453 ) E:( 600 )\n", | |
"Seen: Q:( 63+791 ) T:( 854 ) \u001b[91m☒\u001b[0m G:( 1454 ) E:( 600 )\n", | |
"Q:( 64+791 ) T:( 855 ) \u001b[91m☒\u001b[0m G:( 1455 ) E:( 600 )\n", | |
"Q:( 65+791 ) T:( 856 ) \u001b[91m☒\u001b[0m G:( 1456 ) E:( 600 )\n", | |
"Seen: Q:( 66+791 ) T:( 857 ) \u001b[91m☒\u001b[0m G:( 1457 ) E:( 600 )\n", | |
"Q:( 67+791 ) T:( 858 ) \u001b[91m☒\u001b[0m G:( 1458 ) E:( 600 )\n", | |
"Seen: Q:( 68+791 ) T:( 859 ) \u001b[91m☒\u001b[0m G:( 1459 ) E:( 600 )\n", | |
"Q:( 69+791 ) T:( 860 ) \u001b[91m☒\u001b[0m G:( 1360 ) E:( 500 )\n", | |
"Q:( 70+791 ) T:( 861 ) \u001b[91m☒\u001b[0m G:( 1461 ) E:( 600 )\n", | |
"Q:( 71+791 ) T:( 862 ) \u001b[91m☒\u001b[0m G:( 1462 ) E:( 600 )\n", | |
"Q:( 72+791 ) T:( 863 ) \u001b[91m☒\u001b[0m G:( 1363 ) E:( 500 )\n", | |
"Q:( 73+791 ) T:( 864 ) \u001b[91m☒\u001b[0m G:( 1364 ) E:( 500 )\n", | |
"Q:( 74+791 ) T:( 865 ) \u001b[91m☒\u001b[0m G:( 1365 ) E:( 500 )\n", | |
"Q:( 75+791 ) T:( 866 ) \u001b[91m☒\u001b[0m G:( 1366 ) E:( 500 )\n", | |
"Q:( 76+791 ) T:( 867 ) \u001b[91m☒\u001b[0m G:( 1367 ) E:( 500 )\n", | |
"Q:( 77+791 ) T:( 868 ) \u001b[91m☒\u001b[0m G:( 1368 ) E:( 500 )\n", | |
"Seen: Q:( 78+791 ) T:( 869 ) \u001b[91m☒\u001b[0m G:( 1369 ) E:( 500 )\n", | |
"Q:( 79+791 ) T:( 870 ) \u001b[91m☒\u001b[0m G:( 1370 ) E:( 500 )\n", | |
"Q:( 80+791 ) T:( 871 ) \u001b[91m☒\u001b[0m G:( 1370 ) E:( 499 )\n", | |
"Seen: Q:( 81+791 ) T:( 872 ) \u001b[91m☒\u001b[0m G:( 1372 ) E:( 500 )\n", | |
"Q:( 82+791 ) T:( 873 ) \u001b[91m☒\u001b[0m G:( 1373 ) E:( 500 )\n", | |
"Q:( 83+791 ) T:( 874 ) \u001b[91m☒\u001b[0m G:( 1374 ) E:( 500 )\n", | |
"Q:( 84+791 ) T:( 875 ) \u001b[91m☒\u001b[0m G:( 1375 ) E:( 500 )\n", | |
"Seen: Q:( 85+791 ) T:( 876 ) \u001b[91m☒\u001b[0m G:( 1376 ) E:( 500 )\n", | |
"Q:( 86+791 ) T:( 877 ) \u001b[91m☒\u001b[0m G:( 1377 ) E:( 500 )\n", | |
"Q:( 87+791 ) T:( 878 ) \u001b[91m☒\u001b[0m G:( 1378 ) E:( 500 )\n", | |
"Seen: Q:( 88+791 ) T:( 879 ) \u001b[91m☒\u001b[0m G:( 1379 ) E:( 500 )\n", | |
"Q:( 89+791 ) T:( 880 ) \u001b[91m☒\u001b[0m G:( 1380 ) E:( 500 )\n", | |
"Seen: Q:( 90+791 ) T:( 881 ) \u001b[91m☒\u001b[0m G:( 1470 ) E:( 589 )\n", | |
"Q:( 91+791 ) T:( 882 ) \u001b[91m☒\u001b[0m G:( 1381 ) E:( 499 )\n", | |
"Seen: Q:( 92+791 ) T:( 883 ) \u001b[91m☒\u001b[0m G:( 1383 ) E:( 500 )\n", | |
"Q:( 93+791 ) T:( 884 ) \u001b[91m☒\u001b[0m G:( 1384 ) E:( 500 )\n", | |
"Q:( 94+791 ) T:( 885 ) \u001b[91m☒\u001b[0m G:( 1485 ) E:( 600 )\n", | |
"Q:( 95+791 ) T:( 886 ) \u001b[91m☒\u001b[0m G:( 1486 ) E:( 600 )\n", | |
"Q:( 96+791 ) T:( 887 ) \u001b[91m☒\u001b[0m G:( 1487 ) E:( 600 )\n", | |
"Q:( 97+791 ) T:( 888 ) \u001b[91m☒\u001b[0m G:( 1388 ) E:( 500 )\n", | |
"Seen: Q:( 98+791 ) T:( 889 ) \u001b[91m☒\u001b[0m G:( 1389 ) E:( 500 )\n", | |
"Q:( 99+791 ) T:( 890 ) \u001b[91m☒\u001b[0m G:( 1380 ) E:( 490 )\n", | |
"Seen: Q:( 100+791 ) T:( 891 ) \u001b[92m☑\u001b[0m G:( 891 ) E:( 0 )\n" | |
] | |
} | |
], | |
"source": [ | |
"for i in range(0,101):\n", | |
" calculate_add(i, 791)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment