Skip to content

Instantly share code, notes, and snippets.

@sjchoi86
Created April 7, 2020 01:06
Show Gist options
  • Save sjchoi86/34b4c024ae633b66909474550019c1da to your computer and use it in GitHub Desktop.
Save sjchoi86/34b4c024ae633b66909474550019c1da to your computer and use it in GitHub Desktop.
vibroptml/scripts/demo_tf_rnn.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "### Basic Classification using LSTM"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import numpy as np\nimport tensorflow as tf\nimport matplotlib.pyplot as plt\n%matplotlib inline \nprint (tf.__version__)",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": "1.12.0\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Define an LSTM classifier"
},
{
"metadata": {
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": "tf.reset_default_graph() # reset graph \n# Define RNN \nrnn_d_state,rnn_n_timesteps,rnn_d_input = 16,28,28 # state dim, time steps, input dim\nlstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=rnn_d_state) # we use LSTM\nrnn_x = tf.placeholder(tf.float32, [None,None,rnn_d_input]) # [N x time_step x dim]\nrnn_istate = tf.placeholder(tf.float32, [None,rnn_d_state]) # initial state\nrnn_ioutput = tf.placeholder(tf.float32, [None,rnn_d_state]) # initial output\nrnn_outputs,rnn_final_state = tf.nn.dynamic_rnn(cell=lstm_cell,inputs=rnn_x,dtype=tf.float32,\n initial_state=tf.nn.rnn_cell.LSTMStateTuple(rnn_istate,rnn_ioutput))\nrnn_last_output = rnn_outputs[:,-1,:] # use the last output\n# Define MLP\nrnn_d_output = 10 # output dim\nmodel_hid = tf.layers.dense(rnn_last_output,32,activation=tf.nn.relu) \nmodel_out = tf.layers.dense(model_hid,rnn_d_output,activation=None)\ny = tf.placeholder(tf.float32, [None,rnn_d_output]) # target\n# Loss\ncost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=model_out,labels=y)) \noptm = tf.train.AdamOptimizer(1e-3).minimize(cost)\naccr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(model_out,1),tf.argmax(y,1)),tf.float32))\nprint (\"Done.\")",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": "Done.\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Load MNIST"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.mnist.load_data()\ndef to_onehot(vec,n):\n return np.eye(n)[vec]\ntrain_images,test_images = train_images.astype(np.float32)/255.,test_images.astype(np.float32)/255.\ntrain_labels,test_labels = to_onehot(train_labels,10),to_onehot(test_labels,10)\nn_train,n_test = train_images.shape[0],test_images.shape[0]\nprint (\"train_images:%s train_labels:%s test_images:%s test_labels:%s\"%\n (train_images.shape,train_labels.shape,test_images.shape,test_labels.shape))",
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": "train_images:(60000, 28, 28) train_labels:(60000, 10) test_images:(10000, 28, 28) test_labels:(10000, 10)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def plot_img_label(img,label):\n plt.imshow(img)\n plt.colorbar()\n plt.title(\"Label: %d\"%(np.argmax(label)))\n plt.show()\nr_idx = np.random.permutation(n_train)[0] # random index\nimg,label = train_images[r_idx,:,:],train_labels[r_idx,:]\nplot_img_label(img,label) # plot random image",
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAS4AAAEICAYAAADhtRloAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGR9JREFUeJzt3XuUHnWd5/H3J01IL+FiYhzMJOEiZsTgaNQs6IGjcREJHDU67kbi6uAMs3HPGkcd0GXZWWQYnUFUvGbZbSXLxQuwXCSrWQLDwEFHwSQImIuM2cglmZBAEiAggXT3d/+oijx9eaqq+3merqrO53VOnX6qvlW/+vUD/c3v96tfVSkiMDOrkwllV8DMbKScuMysdpy4zKx2nLjMrHacuMysdpy4zKx2nLgOYJLukvQXY32sWaucuMYBSQ9LemfZ9cgi6dOSHpf0jKTlkiaVXSerLycu6zhJpwPnA6cCRwOvAv6m1EpZrTlxjWOSpkj6kaQnJO1OP88ctNtxkn6RtoRukTS14fi3SPqZpKckPSBp/iircjZwRUSsj4jdwN8CHx1lWWZOXOPcBOB/kbRyjgKeB741aJ8/Bf4cmA70At8AkDQD+DHweWAqcB5wo6RXDD6JpKPS5HZUk3qcADzQsP4AcKSkl4/y97IDnBPXOBYROyPixoj4XUTsAb4AvH3QbtdExLqIeA74b8AiSV3Ah4GVEbEyIvoj4nZgDXDmMOd5NCJeFhGPNqnKocDTDev7Px/Wwq9nB7CDyq6AdY6kQ4CvAguAKenmwyR1RURfuv5YwyGPABOBaSSttH8n6T0N8YnAnaOoyrPA4Q3r+z/vGUVZZm5xjXPnAq8BToqIw4G3pdvVsM+shs9HAfuAJ0kS2jVpS2r/MjkiLhlFPdYDb2hYfwOwPSJ2jqIsMyeucWSipO6G5SCSrtjzwFPpoPvnhjnuw5LmpK2zi4Eb0tbYd4H3SDpdUlda5vxhBveLuBo4Jz3Py4C/Bq4czS9pBk5c48lKkiS1f7kI+Brwr0haUPcAtw5z3DUkSeRxoBv4S4CIeAxYCFwAPEHSAvsMw/w/kw7OP9tscD4ibgUuJelmPkrSJR0uiZoVIj9I0Mzqxi0uM6sdJy4z65j09q4dktY1iUvSNyRtkvSgpDcVKdeJy8w66UqS6TjNnAHMTpclwOVFCnXiMrOOiYi7gV0ZuywEro7EPcDLJE3PK3dMJ6AerEnRzeSxPKXZAWUvz/FivKD8PZs7/R2TY+euvvwdgbUPvrAe2NuwqSciekZwuhkMnAS9Jd22LeuglhKXpAXA14Eu4Dt5kxO7mcxJOrWVU5pZhnvjjpbL2Lmrj1+sanbb6UBd03+zNyLmtXzSERp14krvZ1sGnEaSJVdLWhERG9pVOTMbewH00z9Wp9vKwLs3ZqbbMrUyxnUisCkiNkfEi8C1JP1VM6uxINgXfYWWNlgB/Gl6dfEtwNMRkdlNhNa6isP1TU8avJOkJSRXC+jmkBZOZ2ZjpV0tLkk/AOYD0yRtIbljYiJARPwPkjs+zgQ2Ab8D/qxIuR0fnE8H6noADtdUT9M3q7gg6GvTHTURsTgnHsDHR1puK4lrVH1TM6u+fqrdxmglca0GZks6liRhnQV8qC21MrPSBNA3XhNXRPRKWgqsIpkOsTwi1retZmZWmvHc4iIiVpIMrpnZOBHAvoo/NcaPbjazAYIYv11FMxunAvqqnbecuMxsoGTmfLU5cZnZIKKPlu7T7jgnLjMbIBmcd+IysxpJ5nE5cZlZzfS7xWVmdeIWl5nVTiD6Kv5UdycuMxvCXUUzq5VAvBhdZVcjkxOXmQ2QTEB1V9HMasaD82ZWKxGiL9ziMrOa6XeLy8zqJBmcr3ZqqHbtzGzMeXDezGqpz/O4zKxOPHPezGqp31cVzaxOkpusnbisxg6a8YeZ8U3/6ejM+MoPf6lp7LiJh2Yee9rG92Sfe/MrM+NZXnv+/8uM9+3cNeqy6y4Q+3zLj5nVSQSegGpmdSNPQDWzegnc4jKzGvLgvJnVSiA/SNDM6iV5PVm1U0O1a2dmJfALYa3i9nzwLZnx8z9/dWb89EOezjnDpKaRfdGXeeSPjr8pu+jjc06dYenr52fG/+VPsuev9W79l9GfvOKCcT5zXtLDwB6gD+iNiHntqJSZlavqLa52pNV3RMRcJy2z8SFC9MeEQksRkhZIekjSJknnDxM/StKdkn4p6UFJZ+aV6a6imQ2QDM6355YfSV3AMuA0YAuwWtKKiNjQsNtfA9dHxOWS5gArgWOyym21xRXAbZLWSlrSpOJLJK2RtGYfL7R4OjPrvOSZ80WWAk4ENkXE5oh4EbgWWDhonwAOTz8fAeQOILba4jolIrZK+gPgdkm/joi7B9QoogfoAThcU6PF85lZhyWD84XHuKZJWtOw3pP+ze83A3isYX0LcNKgMi4iaQB9ApgMvDPvpC0lrojYmv7cIelmkux6d/ZRZlZ1I5g5/2QbxrcXA1dGxFckvRW4RtLrIqK/2QGj7ipKmizpsP2fgXcB60ZbnplVw/6Z80WWArYCsxrWZ6bbGp0DXA8QET8HuoFpWYW20uI6ErhZ0v5yvh8Rt7ZQnnXAUx95a2b8p5d8KzPeT9N/9AqZc90nmsaOui17HleeXa+ZmBn/P+de2jT232dmdwze9p1FmfHDz8gM114bX5axGpgt6ViShHUW8KFB+zwKnApcKem1JInriaxCR524ImIz8IbRHm9m1RQB+/rbk7giolfSUmAV0AUsj4j1ki4G1kTECuBc4NuSPk0yxPbRiMgcD/d0CDMbIOkqtm/mfESsJJni0LjtwobPG4CTR1KmE5eZDVH1mfNOXGY2wAinQ5TCicvMBmlvV7ETnLjMbAg/c946LmvKw21/f1nO0QdnRu98PvsVYud955zM+Owv/6JpLHp7M4/N88qcyTeLdp7XNPaTLy7LPPaW12U/zudD72g+zQOg6877MuNVllxV9OvJzKxG/OhmM6sldxXNrFZ8VdHMaslXFc2sViJErxOXmdWNu4pmVise47K2iJPnZsb/7qKeprFDlD1P67bnJ2fGv/nBD2TGZ6z9WWa8zEfeHvHde5rGJnwx+w9zyoTuzPjT5z6bGZ96Z2a48py4zKxWPI/LzGrJ87jMrFYioLdNDxLsFCcuMxvCXUUzqxWPcZlZLYUTl5nVjQfnrWVPXfBcZvyU7r1NY3kvF8ubpxVr1+eUUE8Lfj34LfAD/ej4m8aoJtUT4TEuM6sd0eerimZWNx7jMrNa8b2KZlY/kYxzVZkTl5kN4auKZlYr4cF5M6sjdxUt1953n5gZ/9ac7HcAZplzXfb7/169tvkzq+zAVfWrirntQUnLJe2QtK5h21RJt0v6TfpzSmeraWZjJSJJXEWWshTpyF4JLBi07XzgjoiYDdyRrpvZONEfKrSUJTdxRcTdwK5BmxcCV6WfrwLe1+Z6mVmJIootZRntGNeREbEt/fw4cGSzHSUtAZYAdHPIKE9nZmMlEP0Vv6rYcu0iIsh4J0JE9ETEvIiYN5FJrZ7OzMZAFFzKMtrEtV3SdID05472VcnMStXmwXlJCyQ9JGmTpGHHwyUtkrRB0npJ388rc7SJawVwdvr5bOCWUZZjZlXUpiaXpC5gGXAGMAdYLGnOoH1mA/8FODkiTgA+lVdu7hiXpB8A84FpkrYAnwMuAa6XdA7wCLAo/1ewZpYvuywzPvOg7C72mRubP1Nr9mfXZB5b8XmGHbPy+B9mxvOeYzbetXGqw4nApojYDCDpWpKLexsa9vkPwLKI2J2cO3J7cLmJKyIWNwmdmnesmdVPAP39hRPXNEmN/zr2RETjG4pnAI81rG8BThpUxh8BSPonoAu4KCJuzTqpZ86b2UABFG9xPRkR81o840HAbJKe3Uzgbkl/HBFPNTug2tc8zawUbZzHtRWY1bA+M93WaAuwIiL2RcRvgX8mSWRNOXGZ2VDtmw+xGpgt6VhJBwNnkVzca/RDktYWkqaRdB03ZxXqrqKZDdK++xAjolfSUmAVyfjV8ohYL+liYE1ErEhj75K0AegDPhMRO7PKdeIys6HaeLk5IlYCKwdtu7DhcwB/lS6FOHG1Qdfhh2fG+27Ojh9zUPatULv7n8+M71g1s2nsD3sfzTx2POt6+dSmsQm5T/jMHkV55rnuzHjzM9dAQBS/qlgKJy4zG4YTl5nVTcVnJjtxmdlQTlxmVisjm4BaCicuMxvCL8sws/rxVUUzqxu5xTX+bT/rhMz4z47/Rma8P2fO0Fu/d15m/FVf+llm/EC18Quvbhrr5/bMY3/buzczfvTXqt0iaUnZjzctwInLzAaRB+fNrIbc4jKz2qn4I2CduMxsIM/jMrM68lVFM6ufiicuPwHVzGrHLa426P6T7S0d//MXujLjf/TNRzLjvS2dvb66XtN8nhbAdacvy4hm/5u98Bf/MTN+1M8fyIzXnbuKZlYvgW/5MbMacovLzOrGXUUzqx8nLjOrHScuM6sThbuKZlZHvqo4PvT+mzc3jd18QvbztuDgzOhnL8yeM3TE1ntyyj8w/fovX54Zf0PG1/5ob/a7Kmd+48D+06h6iyt35ryk5ZJ2SFrXsO0iSVsl3Z8uZ3a2mmY2pqLgUpIit/xcCSwYZvtXI2JuuqwcJm5mdRQvjXPlLWXJTVwRcTewawzqYmZVMQ5aXM0slfRg2pWc0mwnSUskrZG0Zh8vtHA6Mxsr6i+2lGW0iety4DhgLrAN+EqzHSOiJyLmRcS8iUwa5enMzF4yqsQVEdsjoi8i+oFvAye2t1pmVqrx2FWUNL1h9f3Aumb7mlnN1GBwPneyiqQfAPOBaZK2AJ8D5kuaS5JzHwY+1sE6VsJj72o+KWjKhO6Wyj7iu56nNZzfXvLWzPhD7/tWTgnN/11e8mefzDzyoJ+szSl7nKv4PK7cxBURi4fZfEUH6mJmVVH3xGVmBxZR7hXDIvzMeTMbqM1jXJIWSHpI0iZJ52fs9wFJIWleXplOXGY2VJuuKkrqApYBZwBzgMWS5gyz32HAJ4F7i1TPicvMhmrfdIgTgU0RsTkiXgSuBRYOs9/fAl8E9hYp1InLzIYYQVdx2v47Y9JlyaCiZgCPNaxvSbe9dC7pTcCsiPhx0fp5cL6gl/3xk01j/VR8JLOi4uS5mfEvf+CqzHje937CXYP/hl5y3D8e4NMd8hS/qvhkROSOSTUjaQJwGfDRkRznxGVmA0VbrypuBWY1rM9Mt+13GPA64C5JAK8EVkh6b0SsaVaoE5eZDdW+eVyrgdmSjiVJWGcBH/r9aSKeBqbtX5d0F3BeVtICj3GZ2TDaNR0iInqBpcAqYCNwfUSsl3SxpPeOtn5ucZnZUG2cOZ8+aHTloG0XNtl3fpEynbjMbKCSn/xQhBOXmQ0gqv+yDCcuMxvCicsOaHrzCU1jn736msxjT+nOnkT97zefkRmffc7GpjHPvMvhxGVmtePEZWa1UvLTTYtw4jKzoZy4zKxuqv4gQScuMxvCXUUzqxdPQDWzWnLisgt3/Ouyq9AxWfO0APb+/XNNY2/rfjHz2DuePzQz/sxnZ2TGtfeBzLgNzzPnzayW1F/tzOXEZWYDeYzLzOrIXUUzqx8nLjOrG7e4zKx+nLjMrFba+5afjshNXJJmAVcDR5Lk4Z6I+LqkqcB1wDHAw8CiiNjduapW14Scd458/g+y3+H3bt7czuqMiCYenBn/zaVvyoz/etGyUZ/7tucnZ8a/+cEPZMa11vO0OqEO87iKvOWnFzg3IuYAbwE+LmkOcD5wR0TMBu5I181sPIgotpQkN3FFxLaIuC/9vIfkFUMzgIXA/lcNXwW8r1OVNLOx1a7Xk3XKiMa4JB0DvBG4FzgyIralocdJupJmVnfjaQKqpEOBG4FPRcQz6euyAYiIkIbPv5KWAEsAujmktdqa2Zio+uB8oTdZS5pIkrS+FxE3pZu3S5qexqcDO4Y7NiJ6ImJeRMybyKR21NnMOkz9xZay5CYuJU2rK4CNEXFZQ2gFcHb6+WzglvZXz8zGXFD5wfkiXcWTgY8Av5J0f7rtAuAS4HpJ5wCPAIs6U8VqeGL7EU1j/S2+7Gr3j2dnxl+47RXZBah5aNJpT2QeOqX7+cz4huO/mRlv5TfPm+4Qa9e3ULq1ourTIXITV0T8lOZ/Gqe2tzpmVgl1T1xmdmCpwwRUJy4zGyjCDxI0sxqqdt5y4jKzodxVNLN6CcBdRTOrnWrnLSeuol77mc1NY/Onn5V57F2vvzYz/k9zs+P9czs3RTnvkTx5Z37tHR/LjB/9/a6msYPXrs4p3crSzq6ipAXA14Eu4DsRccmg+F8Bf0HyJJongD+PiEeyyix0y4+ZHVjUH4WW3HKkLmAZcAYwB1icPhar0S+BeRHxeuAG4NK8cp24zGygGMGS70RgU0RsjogXgWtJHon10uki7oyI36Wr9wAz8wp1V9HMBkgmoBbuK06TtKZhvSciehrWZwCPNaxvAU7KKO8c4P/mndSJy8yGKj6s+mREzGvHKSV9GJgHvD1vXycuMxtiBC2uPFuBWQ3rM9NtA88nvRP4r8DbI+KFvEI9xmVmA7V3jGs1MFvSsZIOBs4ieSTW70l6I/A/gfdGxLDP9RvMLS4zG6R99ypGRK+kpcAqkukQyyNivaSLgTURsQL4EnAo8L/TJys/GhHvzSrXiaugvp27msamfLA389g5f/OJzHgrr/jKc8JdSzLjfc9l/y8w89bsRvlrVma/Iqx/797MuFVUGx8SGBErgZWDtl3Y8PmdIy3TicvMBhoPL4Q1swNQiY9lLsKJy8yGqnbecuIys6HUX+2+ohOXmQ0UtPYWlDHgxGVmA4ho5wTUjnDiMrOhnLjGv75nnsmMv/rT92TG3/3pN7ezOgMcxy87VjZUvkdho+XEZWa14jEuM6sjX1U0s5oJdxXNrGYCJy4zq6Fq9xSduMxsKM/jMrP6qXjiyn0CqqRZku6UtEHSekmfTLdfJGmrpPvT5czOV9fMOi4C+vqLLSUp0uLqBc6NiPskHQaslXR7GvtqRHy5c9Uzs1JUvMWVm7giYhuwLf28R9JGklcOmdl4VfHENaKXZUg6BngjcG+6aamkByUtlzSlyTFLJK2RtGYfuS/vMLOyBdAfxZaSFE5ckg4FbgQ+FRHPAJcDxwFzSVpkXxnuuIjoiYh5ETFvIpPaUGUz66yA6C+2lKTQVUVJE0mS1vci4iaAiNjeEP828KOO1NDMxlZQ6sB7EUWuKgq4AtgYEZc1bJ/esNv7gXXtr56ZlSKi2FKSIi2uk4GPAL+SdH+67QJgsaS5JPn5YeBjHamhmY29ig/OF7mq+FNAw4RWDrPNzGrPN1mbWd0E4MfamFntuMVlZvUSlb+q6MRlZgMFRIlztIpw4jKzoUqcFV+EE5eZDeUxLjOrlQhfVTSzGnKLy8zqJYi+vrIrkcmJy8wG2v9Ymwpz4jKzoSo+HWJEDxI0s/EvgOiPQksRkhZIekjSJknnDxOfJOm6NH5v+sDSTE5cZjZQtO9BgpK6gGXAGcAckqfKzBm02znA7oh4NfBV4It55TpxmdkQ0ddXaCngRGBTRGyOiBeBa4GFg/ZZCFyVfr4BODV9DmBTYzrGtYfdT/5D3PBIw6ZpwJNjWYcRqGrdqlovcN1Gq511O7rVAvawe9U/xA3TCu7eLWlNw3pPRPQ0rM8AHmtY3wKcNKiM3+8TEb2SngZeTsZ3MqaJKyJe0bguaU1EzBvLOhRV1bpVtV7guo1W1eoWEQvKrkMedxXNrJO2ArMa1mem24bdR9JBwBHAzqxCnbjMrJNWA7MlHSvpYOAsYMWgfVYAZ6ef/y3wjxHZU/fLnsfVk79Laapat6rWC1y30apy3VqSjlktBVYBXcDyiFgv6WJgTUSsIHkZzzWSNgG7SJJbJuUkNjOzynFX0cxqx4nLzGqnlMSVdwtAmSQ9LOlXku4fND+ljLosl7RD0rqGbVMl3S7pN+nPKRWq20WStqbf3f2SziypbrMk3Slpg6T1kj6Zbi/1u8uoVyW+tzoZ8zGu9BaAfwZOI5mMthpYHBEbxrQiTUh6GJgXEaVPVpT0NuBZ4OqIeF267VJgV0Rckib9KRHxnytSt4uAZyPiy2Ndn0F1mw5Mj4j7JB0GrAXeB3yUEr+7jHotogLfW52U0eIqcguAARFxN8lVlkaNt0dcRfI//phrUrdKiIhtEXFf+nkPsJFkdnap311GvWyEykhcw90CUKX/eAHcJmmtpCVlV2YYR0bEtvTz48CRZVZmGEslPZh2JUvpxjZKnzTwRuBeKvTdDaoXVOx7qzoPzg91SkS8ieRu9o+nXaJKSifpVWk+y+XAccBcYBvwlTIrI+lQ4EbgUxHxTGOszO9umHpV6nurgzISV5FbAEoTEVvTnzuAm0m6tlWyPR0r2T9msqPk+vxeRGyPiL5IXsr3bUr87iRNJEkO34uIm9LNpX93w9WrSt9bXZSRuIrcAlAKSZPTQVMkTQbeBazLPmrMNd4ecTZwS4l1GWB/Uki9n5K+u/SRKFcAGyPisoZQqd9ds3pV5Xurk1JmzqeXe7/GS7cAfGHMKzEMSa8iaWVBcjvU98usm6QfAPNJHnuyHfgc8EPgeuAo4BFgUUSM+SB5k7rNJ+nuBPAw8LGGMaWxrNspwE+AXwH7n3Z3Acl4UmnfXUa9FlOB761OfMuPmdWOB+fNrHacuMysdpy4zKx2nLjMrHacuMysdpy4zKx2nLjMrHb+P+wMXWAQb/8HAAAAAElFTkSuQmCC\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Train"
},
{
"metadata": {
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": "max_iter,batch_size,PRINT_EVERY = 5e3,128,100\nseq_len = 28 # sequence length [14/28]\nsess = tf.Session() \nsess.run(tf.global_variables_initializer())\nfor iters in range(int(max_iter)):\n r_idxs = np.random.permutation(n_train)[:batch_size]\n x_train,y_train = train_images[r_idxs,:seq_len,:],train_labels[r_idxs,:]\n istate = np.zeros(shape=(batch_size,rnn_d_state))\n ioutput = np.zeros(shape=(batch_size,rnn_d_state))\n cost_val,_ = sess.run([cost,optm],feed_dict={\n rnn_x:x_train,rnn_istate:istate,rnn_ioutput:ioutput,y:y_train\n })\n if ((iters+1)%PRINT_EVERY) == 0:\n x_test = test_images[:,:seq_len,:]\n istate = np.zeros(shape=(n_test,rnn_d_state))\n ioutput = np.zeros(shape=(n_test,rnn_d_state))\n accr_val = sess.run(accr,feed_dict={ # compute accuracy of the model \n rnn_x:x_test,rnn_istate:istate,rnn_ioutput:ioutput,y:test_labels\n })\n print (\"[%04d/%d] accr:[%.3f]\"%(iters+1,max_iter,accr_val))\nprint (\"Done.\")",
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": "[0100/5000] accr:[0.389]\n[0200/5000] accr:[0.605]\n[0300/5000] accr:[0.712]\n[0400/5000] accr:[0.758]\n[0500/5000] accr:[0.783]\n[0600/5000] accr:[0.812]\n[0700/5000] accr:[0.826]\n[0800/5000] accr:[0.849]\n[0900/5000] accr:[0.864]\n[1000/5000] accr:[0.872]\n[1100/5000] accr:[0.880]\n[1200/5000] accr:[0.883]\n[1300/5000] accr:[0.890]\n[1400/5000] accr:[0.895]\n[1500/5000] accr:[0.902]\n[1600/5000] accr:[0.895]\n[1700/5000] accr:[0.905]\n[1800/5000] accr:[0.903]\n[1900/5000] accr:[0.915]\n[2000/5000] accr:[0.912]\n[2100/5000] accr:[0.918]\n[2200/5000] accr:[0.923]\n[2300/5000] accr:[0.927]\n[2400/5000] accr:[0.925]\n[2500/5000] accr:[0.916]\n[2600/5000] accr:[0.933]\n[2700/5000] accr:[0.931]\n[2800/5000] accr:[0.939]\n[2900/5000] accr:[0.940]\n[3000/5000] accr:[0.940]\n[3100/5000] accr:[0.943]\n[3200/5000] accr:[0.946]\n[3300/5000] accr:[0.946]\n[3400/5000] accr:[0.950]\n[3500/5000] accr:[0.948]\n[3600/5000] accr:[0.949]\n[3700/5000] accr:[0.950]\n[3800/5000] accr:[0.947]\n[3900/5000] accr:[0.944]\n[4000/5000] accr:[0.952]\n[4100/5000] accr:[0.954]\n[4200/5000] accr:[0.956]\n[4300/5000] accr:[0.954]\n[4400/5000] accr:[0.954]\n[4500/5000] accr:[0.952]\n[4600/5000] accr:[0.958]\n[4700/5000] accr:[0.957]\n[4800/5000] accr:[0.959]\n[4900/5000] accr:[0.955]\n[5000/5000] accr:[0.953]\nDone.\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Test with partial observations"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "seq_lens = [14,20,24,28]\nfor seq_len in seq_lens:\n x_test = test_images[:,:seq_len,:]\n istate = np.zeros(shape=(n_test,rnn_d_state))\n ioutput = np.zeros(shape=(n_test,rnn_d_state))\n accr_val = sess.run(accr,feed_dict={ # compute accuracy of the model \n rnn_x:x_test,rnn_istate:istate,rnn_ioutput:ioutput,y:test_labels\n })\n print (\"Accuracy of [%d/28] observations is [%.3f]\"%(seq_len,accr_val))",
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": "Accuracy of [14/28] observations is [0.153]\nAccuracy of [20/28] observations is [0.327]\nAccuracy of [24/28] observations is [0.687]\nAccuracy of [28/28] observations is [0.953]\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.7",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "vibroptml/scripts/demo_tf_rnn.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment