Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save tnarihi/629cfd8bc1359018a56d to your computer and use it in GitHub Desktop.
Save tnarihi/629cfd8bc1359018a56d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"name": "",
"signature": "sha256:086465db54efd6ff93e4d9d322a4ea50c0b28201abf0abef5b160b793c0ccc07"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"%pylab inline\n",
"import numpy as np\n",
"import scipy.sparse\n",
"import theano as tn\n",
"import theano.tensor as T\n",
"import theano.sparse"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Load the sample sparse dataset (20 newsgroups, train and test splits)\n",
"from sklearn.datasets import fetch_20newsgroups_vectorized\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"ng = fetch_20newsgroups_vectorized(subset='train')\n",
"ng_test = fetch_20newsgroups_vectorized(subset='test')\n",
"x = ng.data\n",
"# Convert the integer class labels to a 1-of-K (one-hot) target matrix.\n",
"# float32 matches the T.fmatrix target placeholder, so minibatches need no downcast.\n",
"labels = ng.target.astype(np.intp)\n",
"t = np.zeros((labels.shape[0], labels.max() + 1), dtype=np.float32)\n",
"t[np.arange(labels.shape[0]), labels] = 1\n",
"del labels\n",
"# Problem dimensions, used to size the model parameters\n",
"n_class = t.shape[1]\n",
"n_features = x.shape[1]"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Trainable parameters as Theano shared variables, zero-initialised.\n",
"# float32 keeps them consistent with the float32 (csc_fmatrix / fmatrix / fscalar) inputs.\n",
"sh_w = tn.shared(np.zeros(shape=(n_features, n_class), dtype=np.float32))\n",
"sh_b = tn.shared(np.zeros(shape=(n_class,), dtype=np.float32))"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 3
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Symbolic inputs of the training / prediction graphs:\n",
"# a sparse CSC float32 feature matrix, dense float32 one-hot targets,\n",
"# and float32 scalars for the optimiser hyper-parameters.\n",
"sy_x = theano.sparse.csc_fmatrix('X')\n",
"sy_t = T.fmatrix('T')\n",
"sy_learning_rate, sy_momentum = T.fscalar('learning_rate'), T.fscalar('momentum')"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 4
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Linear classifier: per-class scores (logits) = bias + X . W, bias broadcast over rows\n",
"sy_y = sh_b + tn.dot(sy_x, sh_w)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 5
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Softmax cross-entropy loss, summed (not averaged) over the rows of the minibatch.\n",
"# NOTE(review): log(softmax) against a one-hot *matrix* target may not be rewritten into a\n",
"# numerically stable op by every Theano version -- confirm if the loss ever becomes inf/NaN.\n",
"sy_softmax = T.nnet.softmax(sy_y)\n",
"sy_loss = T.nnet.categorical_crossentropy(sy_softmax, sy_t).sum()"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 6
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Gradients of the loss w.r.t. every trainable parameter, in the same order as sh_params\n",
"sh_params = [sh_w, sh_b]\n",
"sy_grads = tn.grad(sy_loss, wrt=sh_params)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 7
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Define update function: SGD with momentum (velocity = moving average of gradients).\n",
"# Theano applies all `updates` simultaneously, each computed from the values the shared\n",
"# variables had *before* the call. The new velocity is therefore built once as an\n",
"# expression and reused for the parameter step. Reading the shared buffer `sh_update`\n",
"# directly would apply the previous step's velocity, so the first minibatch would not\n",
"# move the parameters at all.\n",
"sy_updates = []\n",
"for sh_p, sy_g in zip(sh_params, sy_grads):\n",
"    # velocity buffer: same shape/dtype as the parameter, starts at zero\n",
"    sh_update = tn.shared(np.zeros_like(sh_p.get_value(borrow=True)))\n",
"    sy_new_update = sy_momentum * sh_update + (1. - sy_momentum) * sy_g\n",
"    sy_updates.append((sh_update, sy_new_update))\n",
"    sy_updates.append((sh_p, sh_p - sy_learning_rate * sy_new_update))"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 8
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Compile the symbolic graphs into callable Theano functions.\n",
"# allow_input_downcast lets callers pass float64 arrays / Python floats to the float32 inputs.\n",
"f_classify = tn.function(inputs=[sy_x], outputs=sy_y, allow_input_downcast=True)\n",
"f_update = tn.function(\n",
"    inputs=[sy_x, sy_t, sy_learning_rate, sy_momentum],\n",
"    outputs=sy_loss,\n",
"    updates=sy_updates,\n",
"    allow_input_downcast=True,\n",
")"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 9
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Evaluation helper: accuracy of the current model on a labelled dataset\n",
"from sklearn.metrics import accuracy_score\n",
"def print_accuracy(data=None, target=None):\n",
"    \"\"\"Print and return the classification accuracy of the current parameters.\n",
"\n",
"    data   -- feature matrix (sparse or dense); defaults to ng_test.data\n",
"    target -- integer class labels for `data`; defaults to ng_test.target\n",
"    Pass both or neither, so features and labels always come from the same split.\n",
"    \"\"\"\n",
"    if (data is None) != (target is None):\n",
"        raise ValueError('pass both data and target, or neither')\n",
"    if data is None:\n",
"        data, target = ng_test.data, ng_test.target\n",
"    # sy_x is a CSC placeholder: convert explicitly, as the training loop does,\n",
"    # instead of relying on an implicit format conversion at call time.\n",
"    y = f_classify(scipy.sparse.csc_matrix(data))\n",
"    acc = accuracy_score(target, y.argmax(axis=1))\n",
"    print 'Accuracy:', acc\n",
"    return acc"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 10
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Train with minibatch SGD + momentum; report test accuracy and summed training loss per epoch.\n",
"# NOTE: minibatches are visited in the same fixed order every epoch (no shuffling).\n",
"n_epoch = 100\n",
"batch_size = 50\n",
"learning_rate = 0.1\n",
"momentum = 0.9\n",
"n_samples = x.shape[0]\n",
"for epoch in xrange(n_epoch):\n",
"    print '--- epoch %s ---'%epoch\n",
"    # accuracy of the parameters *before* this epoch's updates\n",
"    print_accuracy()\n",
"    loss = 0.0\n",
"    for batch in xrange(0, n_samples, batch_size):\n",
"        # clamp the end index explicitly instead of relying on slice clipping\n",
"        # (older scipy.sparse versions may reject an out-of-range slice stop)\n",
"        bend = min(batch + batch_size, n_samples)\n",
"        loss += f_update(\n",
"            x[batch:bend].tocsc(),\n",
"            t[batch:bend],\n",
"            learning_rate,\n",
"            momentum,\n",
"        )\n",
"    print 'loss:', loss\n",
"# The in-loop report lags by one epoch, so also evaluate the fully trained model.\n",
"print '--- after %s epochs ---' % n_epoch\n",
"print_accuracy();"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"--- epoch 0 ---\n",
"Accuracy: 0.0423526287839\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 29737.8746376\n",
"--- epoch 1 ---\n",
"Accuracy: 0.427243759958\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 23028.6264458\n",
"--- epoch 2 ---\n",
"Accuracy: 0.536112586298\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 19023.5092354\n",
"--- epoch 3 ---\n",
"Accuracy: 0.58775889538\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 16320.2985516\n",
"--- epoch 4 ---\n",
"Accuracy: 0.620685077005\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 14365.3398514\n",
"--- epoch 5 ---\n",
"Accuracy: 0.647769516729\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 12876.8484249\n",
"--- epoch 6 ---\n",
"Accuracy: 0.664630908125\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 11698.8075142\n",
"--- epoch 7 ---\n",
"Accuracy: 0.679899097185\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 10738.5565882\n",
"--- epoch 8 ---\n",
"Accuracy: 0.692113648433\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 9937.64622021\n",
"--- epoch 9 ---\n",
"Accuracy: 0.701672862454\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 9257.27986622\n",
"--- epoch 10 ---\n",
"Accuracy: 0.708178438662\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 8670.59211063\n",
"--- epoch 11 ---\n",
"Accuracy: 0.715878916622\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 8158.32378674\n",
"--- epoch 12 ---\n",
"Accuracy: 0.722251725969\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 7706.27385426\n",
"--- epoch 13 ---\n",
"Accuracy: 0.725969198088\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 7303.7234602\n",
"--- epoch 14 ---\n",
"Accuracy: 0.73088157196\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 6942.42311192\n",
"--- epoch 15 ---\n",
"Accuracy: 0.734333510356\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 6615.91445255\n",
"--- epoch 16 ---\n",
"Accuracy: 0.738316516198\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 6319.06619453\n",
"--- epoch 17 ---\n",
"Accuracy: 0.74044078598\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 6047.74616909\n",
"--- epoch 18 ---\n",
"Accuracy: 0.744556558683\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 5798.5874095\n",
"--- epoch 19 ---\n",
"Accuracy: 0.747742963356\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 5568.81564426\n",
"--- epoch 20 ---\n",
"Accuracy: 0.750663834307\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 5356.12279034\n",
"--- epoch 21 ---\n",
"Accuracy: 0.752788104089\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 5158.57047272\n",
"--- epoch 22 ---\n",
"Accuracy: 0.754248539565\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 4974.51731491\n",
"--- epoch 23 ---\n",
"Accuracy: 0.757434944238\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 4802.56239128\n",
"--- epoch 24 ---\n",
"Accuracy: 0.759028146575\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 4641.50262165\n",
"--- epoch 25 ---\n",
"Accuracy: 0.760355815189\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 4490.29653549\n",
"--- epoch 26 ---\n",
"Accuracy: 0.762214551248\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 4348.03818083\n",
"--- epoch 27 ---\n",
"Accuracy: 0.76526818906\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 4213.93443489\n",
"--- epoch 28 ---\n",
"Accuracy: 0.767259691981\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 4087.28796434\n",
"--- epoch 29 ---\n",
"Accuracy: 0.768056293149\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3967.48210478\n",
"--- epoch 30 ---\n",
"Accuracy: 0.769782262347\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3853.9699769\n",
"--- epoch 31 ---\n",
"Accuracy: 0.77190653213\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3746.2632103\n",
"--- epoch 32 ---\n",
"Accuracy: 0.772835900159\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3643.92566395\n",
"--- epoch 33 ---\n",
"Accuracy: 0.774030801912\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3546.56490088\n",
"--- epoch 34 ---\n",
"Accuracy: 0.77482740308\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3453.82742167\n",
"--- epoch 35 ---\n",
"Accuracy: 0.77575677111\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3365.39394426\n",
"--- epoch 36 ---\n",
"Accuracy: 0.776420605417\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3280.97470403\n",
"--- epoch 37 ---\n",
"Accuracy: 0.777349973447\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3200.3061378\n",
"--- epoch 38 ---\n",
"Accuracy: 0.778412108338\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3123.14804721\n",
"--- epoch 39 ---\n",
"Accuracy: 0.779208709506\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 3049.28037572\n",
"--- epoch 40 ---\n",
"Accuracy: 0.780005310674\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2978.50256801\n",
"--- epoch 41 ---\n",
"Accuracy: 0.780801911843\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2910.62931991\n",
"--- epoch 42 ---\n",
"Accuracy: 0.780801911843\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2845.49062061\n",
"--- epoch 43 ---\n",
"Accuracy: 0.781200212427\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2782.92945194\n",
"--- epoch 44 ---\n",
"Accuracy: 0.781864046734\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2722.80122304\n",
"--- epoch 45 ---\n",
"Accuracy: 0.781996813595\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2664.97131395\n",
"--- epoch 46 ---\n",
"Accuracy: 0.782793414764\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2609.3156178\n",
"--- epoch 47 ---\n",
"Accuracy: 0.783324482209\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2555.71850443\n",
"--- epoch 48 ---\n",
"Accuracy: 0.784253850239\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2504.07247043\n",
"--- epoch 49 ---\n",
"Accuracy: 0.784917684546\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2454.27749181\n",
"--- epoch 50 ---\n",
"Accuracy: 0.78531598513\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2406.2400744\n",
"--- epoch 51 ---\n",
"Accuracy: 0.785847052576\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2359.87282801\n",
"--- epoch 52 ---\n",
"Accuracy: 0.785847052576\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2315.09441638\n",
"--- epoch 53 ---\n",
"Accuracy: 0.78624535316\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2271.82822204\n",
"--- epoch 54 ---\n",
"Accuracy: 0.786378120021\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2230.00282717\n",
"--- epoch 55 ---\n",
"Accuracy: 0.786643653744\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2189.55077243\n",
"--- epoch 56 ---\n",
"Accuracy: 0.786776420605\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2150.40886998\n",
"--- epoch 57 ---\n",
"Accuracy: 0.78717472119\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2112.51788092\n",
"--- epoch 58 ---\n",
"Accuracy: 0.787307488051\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2075.8217051\n",
"--- epoch 59 ---\n",
"Accuracy: 0.786776420605\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2040.26780605\n",
"--- epoch 60 ---\n",
"Accuracy: 0.786909187467\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 2005.8063581\n",
"--- epoch 61 ---\n",
"Accuracy: 0.787971322358\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1972.39035869\n",
"--- epoch 62 ---\n",
"Accuracy: 0.787838555497\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1939.97550058\n",
"--- epoch 63 ---\n",
"Accuracy: 0.787971322358\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1908.52000976\n",
"--- epoch 64 ---\n",
"Accuracy: 0.788502389804\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1877.98406339\n",
"--- epoch 65 ---\n",
"Accuracy: 0.788635156665\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1848.33012223\n",
"--- epoch 66 ---\n",
"Accuracy: 0.788635156665\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1819.52257264\n",
"--- epoch 67 ---\n",
"Accuracy: 0.789033457249\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1791.52758956\n",
"--- epoch 68 ---\n",
"Accuracy: 0.789033457249\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1764.31316543\n",
"--- epoch 69 ---\n",
"Accuracy: 0.789033457249\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1737.84877717\n",
"--- epoch 70 ---\n",
"Accuracy: 0.789298990972\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1712.10554111\n",
"--- epoch 71 ---\n",
"Accuracy: 0.789298990972\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1687.05614507\n",
"--- epoch 72 ---\n",
"Accuracy: 0.789830058417\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1662.67408979\n",
"--- epoch 73 ---\n",
"Accuracy: 0.79009559214\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1638.93454897\n",
"--- epoch 74 ---\n",
"Accuracy: 0.790493892724\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1615.81393421\n",
"--- epoch 75 ---\n",
"Accuracy: 0.790493892724\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1593.28945303\n",
"--- epoch 76 ---\n",
"Accuracy: 0.790493892724\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1571.33972526\n",
"--- epoch 77 ---\n",
"Accuracy: 0.790228359002\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1549.94396472\n",
"--- epoch 78 ---\n",
"Accuracy: 0.790626659586\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1529.08267617\n",
"--- epoch 79 ---\n",
"Accuracy: 0.790626659586\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1508.73700202\n",
"--- epoch 80 ---\n",
"Accuracy: 0.790759426447\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1488.88911879\n",
"--- epoch 81 ---\n",
"Accuracy: 0.79102496017\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1469.52181292\n",
"--- epoch 82 ---\n",
"Accuracy: 0.790759426447\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1450.61870253\n",
"--- epoch 83 ---\n",
"Accuracy: 0.790892193309\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1432.16414738\n",
"--- epoch 84 ---\n",
"Accuracy: 0.79102496017\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1414.14307666\n",
"--- epoch 85 ---\n",
"Accuracy: 0.791157727031\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1396.54127836\n",
"--- epoch 86 ---\n",
"Accuracy: 0.791290493893\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1379.344805\n",
"--- epoch 87 ---\n",
"Accuracy: 0.791688794477\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1362.54072309\n",
"--- epoch 88 ---\n",
"Accuracy: 0.791821561338\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1346.1160171\n",
"--- epoch 89 ---\n",
"Accuracy: 0.7919543282\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1330.05895436\n",
"--- epoch 90 ---\n",
"Accuracy: 0.792087095061\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1314.35758257\n",
"--- epoch 91 ---\n",
"Accuracy: 0.792352628784\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1299.00088787\n",
"--- epoch 92 ---\n",
"Accuracy: 0.792352628784\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1283.97807586\n",
"--- epoch 93 ---\n",
"Accuracy: 0.792485395645\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1269.2790277\n",
"--- epoch 94 ---\n",
"Accuracy: 0.792750929368\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1254.8936621\n",
"--- epoch 95 ---\n",
"Accuracy: 0.792618162507\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1240.81245065\n",
"--- epoch 96 ---\n",
"Accuracy: 0.792750929368\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1227.02650273\n",
"--- epoch 97 ---\n",
"Accuracy: 0.793016463091\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1213.5268079\n",
"--- epoch 98 ---\n",
"Accuracy: 0.793414763675\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1200.30496955\n",
"--- epoch 99 ---\n",
"Accuracy: 0.793813064259\n",
"loss:"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 1187.35288036\n"
]
}
],
"prompt_number": 11
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 11
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment