Created
October 22, 2014 19:18
-
-
Save tnarihi/629cfd8bc1359018a56d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "", | |
"signature": "sha256:086465db54efd6ff93e4d9d322a4ea50c0b28201abf0abef5b160b793c0ccc07" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%pylab inline\n", | |
"import numpy as np\n", | |
"import scipy.sparse\n", | |
"import theano as tn\n", | |
"import theano.tensor as T\n", | |
"import theano.sparse" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"Populating the interactive namespace from numpy and matplotlib\n" | |
] | |
} | |
], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Loading sample sparse dataset\n", | |
"from sklearn.datasets import fetch_20newsgroups_vectorized\n", | |
"from sklearn.feature_extraction.text import TfidfVectorizer\n", | |
"ng = fetch_20newsgroups_vectorized(subset='train')\n", | |
"ng_test = fetch_20newsgroups_vectorized(subset='test')\n", | |
"# Convrert targets to 1-of-K representation\n", | |
"x = ng.data\n", | |
"tmp = ng.target\n", | |
"t = np.zeros((tmp.shape[0], tmp.max() + 1))\n", | |
"t[(xrange(tmp.shape[0]), tmp.astype(np.int32))] = 1\n", | |
"del tmp\n", | |
"# \n", | |
"n_class = t.shape[1]\n", | |
"n_features = x.shape[1]" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Create shared weights and biases\n", | |
"sh_w = tn.shared(np.zeros((n_features, n_class), dtype='float32'))\n", | |
"sh_b = tn.shared(np.zeros((n_class,), dtype='float32'))" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 3 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Create input tensor variables\n", | |
"sy_x = theano.sparse.csc_fmatrix('X')\n", | |
"sy_t = T.fmatrix('T')\n", | |
"sy_learning_rate = T.fscalar('learning_rate')\n", | |
"sy_momentum = T.fscalar('momentum')" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 4 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Define classifier\n", | |
"sy_y = tn.dot(sy_x, sh_w) + sh_b" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 5 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Define softmax loss function\n", | |
"sy_softmax = T.nnet.softmax(sy_y)\n", | |
"sy_loss = T.sum(T.nnet.categorical_crossentropy(sy_softmax, sy_t))" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 6 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Define gradient\n", | |
"sh_params = [sh_w, sh_b]\n", | |
"sy_grads = [tn.grad(sy_loss, wrt=sh_p) for sh_p in sh_params]" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 7 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Define update function\n", | |
"sy_updates = []\n", | |
"for p in xrange(len(sh_params)):\n", | |
" sh_p = sh_params[p]\n", | |
" sy_g = sy_grads[p]\n", | |
" sh_update = tn.shared(np.zeros_like(sh_p.get_value(borrow=True))) # momentum storage\n", | |
" sy_updates.append((sh_update, sy_momentum * sh_update + (1. - sy_momentum) * sy_g))\n", | |
" sy_updates.append((sh_p, sh_p - sy_learning_rate * sh_update))" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 8 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Compile functions\n", | |
"f_classify = tn.function([sy_x], sy_y, allow_input_downcast=True)\n", | |
"f_update = tn.function(\n", | |
" [sy_x, sy_t, sy_learning_rate, sy_momentum], sy_loss,\n", | |
" updates=sy_updates, allow_input_downcast=True\n", | |
")" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 9 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# We also create evalution function\n", | |
"from sklearn.metrics import accuracy_score\n", | |
"def print_accuracy():\n", | |
" y = f_classify(ng_test.data)\n", | |
" print 'Accuracy:', accuracy_score(ng_test.target, y.argmax(axis=1))\n" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 10 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"n_epoch = 100\n", | |
"batch_size = 50\n", | |
"for epoch in xrange(n_epoch):\n", | |
" print '--- epoch %s ---'%epoch\n", | |
" print_accuracy()\n", | |
" loss = 0.0\n", | |
" for batch in xrange(0, x.shape[0], batch_size):\n", | |
" bend = min(batch + batch_size, x.shape[0])\n", | |
" loss += f_update(\n", | |
" x[batch:bend].tocsc(),\n", | |
" t[batch:bend],\n", | |
" 0.1,\n", | |
" 0.9,\n", | |
" )\n", | |
" print 'loss:', loss\n" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"--- epoch 0 ---\n", | |
"Accuracy: 0.0423526287839\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 29737.8746376\n", | |
"--- epoch 1 ---\n", | |
"Accuracy: 0.427243759958\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 23028.6264458\n", | |
"--- epoch 2 ---\n", | |
"Accuracy: 0.536112586298\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 19023.5092354\n", | |
"--- epoch 3 ---\n", | |
"Accuracy: 0.58775889538\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 16320.2985516\n", | |
"--- epoch 4 ---\n", | |
"Accuracy: 0.620685077005\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 14365.3398514\n", | |
"--- epoch 5 ---\n", | |
"Accuracy: 0.647769516729\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 12876.8484249\n", | |
"--- epoch 6 ---\n", | |
"Accuracy: 0.664630908125\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 11698.8075142\n", | |
"--- epoch 7 ---\n", | |
"Accuracy: 0.679899097185\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 10738.5565882\n", | |
"--- epoch 8 ---\n", | |
"Accuracy: 0.692113648433\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 9937.64622021\n", | |
"--- epoch 9 ---\n", | |
"Accuracy: 0.701672862454\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 9257.27986622\n", | |
"--- epoch 10 ---\n", | |
"Accuracy: 0.708178438662\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 8670.59211063\n", | |
"--- epoch 11 ---\n", | |
"Accuracy: 0.715878916622\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 8158.32378674\n", | |
"--- epoch 12 ---\n", | |
"Accuracy: 0.722251725969\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 7706.27385426\n", | |
"--- epoch 13 ---\n", | |
"Accuracy: 0.725969198088\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 7303.7234602\n", | |
"--- epoch 14 ---\n", | |
"Accuracy: 0.73088157196\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 6942.42311192\n", | |
"--- epoch 15 ---\n", | |
"Accuracy: 0.734333510356\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 6615.91445255\n", | |
"--- epoch 16 ---\n", | |
"Accuracy: 0.738316516198\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 6319.06619453\n", | |
"--- epoch 17 ---\n", | |
"Accuracy: 0.74044078598\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 6047.74616909\n", | |
"--- epoch 18 ---\n", | |
"Accuracy: 0.744556558683\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 5798.5874095\n", | |
"--- epoch 19 ---\n", | |
"Accuracy: 0.747742963356\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 5568.81564426\n", | |
"--- epoch 20 ---\n", | |
"Accuracy: 0.750663834307\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 5356.12279034\n", | |
"--- epoch 21 ---\n", | |
"Accuracy: 0.752788104089\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 5158.57047272\n", | |
"--- epoch 22 ---\n", | |
"Accuracy: 0.754248539565\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 4974.51731491\n", | |
"--- epoch 23 ---\n", | |
"Accuracy: 0.757434944238\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 4802.56239128\n", | |
"--- epoch 24 ---\n", | |
"Accuracy: 0.759028146575\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 4641.50262165\n", | |
"--- epoch 25 ---\n", | |
"Accuracy: 0.760355815189\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 4490.29653549\n", | |
"--- epoch 26 ---\n", | |
"Accuracy: 0.762214551248\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 4348.03818083\n", | |
"--- epoch 27 ---\n", | |
"Accuracy: 0.76526818906\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 4213.93443489\n", | |
"--- epoch 28 ---\n", | |
"Accuracy: 0.767259691981\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 4087.28796434\n", | |
"--- epoch 29 ---\n", | |
"Accuracy: 0.768056293149\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3967.48210478\n", | |
"--- epoch 30 ---\n", | |
"Accuracy: 0.769782262347\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3853.9699769\n", | |
"--- epoch 31 ---\n", | |
"Accuracy: 0.77190653213\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3746.2632103\n", | |
"--- epoch 32 ---\n", | |
"Accuracy: 0.772835900159\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3643.92566395\n", | |
"--- epoch 33 ---\n", | |
"Accuracy: 0.774030801912\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3546.56490088\n", | |
"--- epoch 34 ---\n", | |
"Accuracy: 0.77482740308\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3453.82742167\n", | |
"--- epoch 35 ---\n", | |
"Accuracy: 0.77575677111\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3365.39394426\n", | |
"--- epoch 36 ---\n", | |
"Accuracy: 0.776420605417\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3280.97470403\n", | |
"--- epoch 37 ---\n", | |
"Accuracy: 0.777349973447\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3200.3061378\n", | |
"--- epoch 38 ---\n", | |
"Accuracy: 0.778412108338\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3123.14804721\n", | |
"--- epoch 39 ---\n", | |
"Accuracy: 0.779208709506\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 3049.28037572\n", | |
"--- epoch 40 ---\n", | |
"Accuracy: 0.780005310674\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2978.50256801\n", | |
"--- epoch 41 ---\n", | |
"Accuracy: 0.780801911843\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2910.62931991\n", | |
"--- epoch 42 ---\n", | |
"Accuracy: 0.780801911843\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2845.49062061\n", | |
"--- epoch 43 ---\n", | |
"Accuracy: 0.781200212427\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2782.92945194\n", | |
"--- epoch 44 ---\n", | |
"Accuracy: 0.781864046734\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2722.80122304\n", | |
"--- epoch 45 ---\n", | |
"Accuracy: 0.781996813595\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2664.97131395\n", | |
"--- epoch 46 ---\n", | |
"Accuracy: 0.782793414764\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2609.3156178\n", | |
"--- epoch 47 ---\n", | |
"Accuracy: 0.783324482209\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2555.71850443\n", | |
"--- epoch 48 ---\n", | |
"Accuracy: 0.784253850239\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2504.07247043\n", | |
"--- epoch 49 ---\n", | |
"Accuracy: 0.784917684546\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2454.27749181\n", | |
"--- epoch 50 ---\n", | |
"Accuracy: 0.78531598513\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2406.2400744\n", | |
"--- epoch 51 ---\n", | |
"Accuracy: 0.785847052576\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2359.87282801\n", | |
"--- epoch 52 ---\n", | |
"Accuracy: 0.785847052576\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2315.09441638\n", | |
"--- epoch 53 ---\n", | |
"Accuracy: 0.78624535316\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2271.82822204\n", | |
"--- epoch 54 ---\n", | |
"Accuracy: 0.786378120021\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2230.00282717\n", | |
"--- epoch 55 ---\n", | |
"Accuracy: 0.786643653744\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2189.55077243\n", | |
"--- epoch 56 ---\n", | |
"Accuracy: 0.786776420605\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2150.40886998\n", | |
"--- epoch 57 ---\n", | |
"Accuracy: 0.78717472119\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2112.51788092\n", | |
"--- epoch 58 ---\n", | |
"Accuracy: 0.787307488051\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2075.8217051\n", | |
"--- epoch 59 ---\n", | |
"Accuracy: 0.786776420605\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2040.26780605\n", | |
"--- epoch 60 ---\n", | |
"Accuracy: 0.786909187467\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 2005.8063581\n", | |
"--- epoch 61 ---\n", | |
"Accuracy: 0.787971322358\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1972.39035869\n", | |
"--- epoch 62 ---\n", | |
"Accuracy: 0.787838555497\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1939.97550058\n", | |
"--- epoch 63 ---\n", | |
"Accuracy: 0.787971322358\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1908.52000976\n", | |
"--- epoch 64 ---\n", | |
"Accuracy: 0.788502389804\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1877.98406339\n", | |
"--- epoch 65 ---\n", | |
"Accuracy: 0.788635156665\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1848.33012223\n", | |
"--- epoch 66 ---\n", | |
"Accuracy: 0.788635156665\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1819.52257264\n", | |
"--- epoch 67 ---\n", | |
"Accuracy: 0.789033457249\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1791.52758956\n", | |
"--- epoch 68 ---\n", | |
"Accuracy: 0.789033457249\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1764.31316543\n", | |
"--- epoch 69 ---\n", | |
"Accuracy: 0.789033457249\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1737.84877717\n", | |
"--- epoch 70 ---\n", | |
"Accuracy: 0.789298990972\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1712.10554111\n", | |
"--- epoch 71 ---\n", | |
"Accuracy: 0.789298990972\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1687.05614507\n", | |
"--- epoch 72 ---\n", | |
"Accuracy: 0.789830058417\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1662.67408979\n", | |
"--- epoch 73 ---\n", | |
"Accuracy: 0.79009559214\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1638.93454897\n", | |
"--- epoch 74 ---\n", | |
"Accuracy: 0.790493892724\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1615.81393421\n", | |
"--- epoch 75 ---\n", | |
"Accuracy: 0.790493892724\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1593.28945303\n", | |
"--- epoch 76 ---\n", | |
"Accuracy: 0.790493892724\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1571.33972526\n", | |
"--- epoch 77 ---\n", | |
"Accuracy: 0.790228359002\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1549.94396472\n", | |
"--- epoch 78 ---\n", | |
"Accuracy: 0.790626659586\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1529.08267617\n", | |
"--- epoch 79 ---\n", | |
"Accuracy: 0.790626659586\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1508.73700202\n", | |
"--- epoch 80 ---\n", | |
"Accuracy: 0.790759426447\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1488.88911879\n", | |
"--- epoch 81 ---\n", | |
"Accuracy: 0.79102496017\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1469.52181292\n", | |
"--- epoch 82 ---\n", | |
"Accuracy: 0.790759426447\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1450.61870253\n", | |
"--- epoch 83 ---\n", | |
"Accuracy: 0.790892193309\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1432.16414738\n", | |
"--- epoch 84 ---\n", | |
"Accuracy: 0.79102496017\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1414.14307666\n", | |
"--- epoch 85 ---\n", | |
"Accuracy: 0.791157727031\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1396.54127836\n", | |
"--- epoch 86 ---\n", | |
"Accuracy: 0.791290493893\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1379.344805\n", | |
"--- epoch 87 ---\n", | |
"Accuracy: 0.791688794477\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1362.54072309\n", | |
"--- epoch 88 ---\n", | |
"Accuracy: 0.791821561338\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1346.1160171\n", | |
"--- epoch 89 ---\n", | |
"Accuracy: 0.7919543282\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1330.05895436\n", | |
"--- epoch 90 ---\n", | |
"Accuracy: 0.792087095061\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1314.35758257\n", | |
"--- epoch 91 ---\n", | |
"Accuracy: 0.792352628784\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1299.00088787\n", | |
"--- epoch 92 ---\n", | |
"Accuracy: 0.792352628784\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1283.97807586\n", | |
"--- epoch 93 ---\n", | |
"Accuracy: 0.792485395645\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1269.2790277\n", | |
"--- epoch 94 ---\n", | |
"Accuracy: 0.792750929368\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1254.8936621\n", | |
"--- epoch 95 ---\n", | |
"Accuracy: 0.792618162507\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1240.81245065\n", | |
"--- epoch 96 ---\n", | |
"Accuracy: 0.792750929368\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1227.02650273\n", | |
"--- epoch 97 ---\n", | |
"Accuracy: 0.793016463091\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1213.5268079\n", | |
"--- epoch 98 ---\n", | |
"Accuracy: 0.793414763675\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1200.30496955\n", | |
"--- epoch 99 ---\n", | |
"Accuracy: 0.793813064259\n", | |
"loss:" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" 1187.35288036\n" | |
] | |
} | |
], | |
"prompt_number": 11 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 11 | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment