Forked from j-min/RNN_hunkim's_tutorial_BasicRNNCell.ipynb
Created
January 22, 2017 02:09
-
-
Save LKH-1/aa09dff3b5fa35f849a703b137c897cd to your computer and use it in GitHub Desktop.
TensorFlow 0.9 implementation of BasicRNNCell based on hunkim's tutorial
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### BasicRNNCell\n", | |
| "#### TensorFlow 0.9 implementation based on hunkim's tutorial\n", | |
| "https://hunkim.github.io/ml/\n", | |
| "\n", | |
| "https://www.youtube.com/watch?v=A8wJYfDUYCk&feature=youtu.be" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import tensorflow as tf\n", | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "{'o': 3, 'l': 2, 'e': 1, 'h': 0}\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "char_rdic = ['h', 'e', 'l', 'o'] # id -> char\n", | |
| "char_dic = {w : i for i, w in enumerate(char_rdic)} # char -> id\n", | |
| "print (char_dic)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[0, 1, 2, 2, 3]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "ground_truth = [char_dic[c] for c in 'hello']\n", | |
| "print (ground_truth)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "x_data = np.array([[1,0,0,0], # h\n", | |
| " [0,1,0,0], # e\n", | |
| " [0,0,1,0], # l\n", | |
| " [0,0,1,0]], # l\n", | |
| " dtype = 'f')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Tensor(\"one_hot:0\", shape=(4, 4), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "x_data = tf.one_hot(ground_truth[:-1], len(char_dic), 1.0, 0.0, -1)\n", | |
| "print(x_data)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Configuration\n", | |
| "rnn_size = len(char_dic) # 4\n", | |
| "batch_size = 1\n", | |
| "output_size = 4" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "<tensorflow.python.ops.rnn_cell.BasicRNNCell object at 0x7effb759c9e8>\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# RNN Model\n", | |
| "rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units = rnn_size,\n", | |
| " input_size = None, # deprecated at tensorflow 0.9\n", | |
| " #activation = tanh,\n", | |
| " )\n", | |
| "print(rnn_cell)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Tensor(\"zeros:0\", shape=(1, 4), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "initial_state = rnn_cell.zero_state(batch_size, tf.float32)\n", | |
| "print(initial_state)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Tensor(\"zeros_1:0\", shape=(1, 4), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "initial_state_1 = tf.zeros([batch_size, rnn_cell.state_size]) # 위 코드와 같은 결과\n", | |
| "print(initial_state_1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[<tf.Tensor 'split:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:1' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:2' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:3' shape=(1, 4) dtype=float32>]\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'\\n[[1,0,0,0]] # h\\n[[0,1,0,0]] # e\\n[[0,0,1,0]] # l\\n[[0,0,1,0]] # l\\n'" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x_split = tf.split(0, len(char_dic), x_data) # 가로축으로 4개로 split\n", | |
| "print(x_split)\n", | |
| "\"\"\"\n", | |
| "[[1,0,0,0]] # h\n", | |
| "[[0,1,0,0]] # e\n", | |
| "[[0,0,1,0]] # l\n", | |
| "[[0,0,1,0]] # l\n", | |
| "\"\"\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "outputs, state = tf.nn.rnn(cell = rnn_cell, inputs = x_split, initial_state = initial_state)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[<tf.Tensor 'RNN/BasicRNNCell/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_1/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_2/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_3/Tanh:0' shape=(1, 4) dtype=float32>]\n", | |
| "Tensor(\"RNN/BasicRNNCell_3/Tanh:0\", shape=(1, 4), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print (outputs)\n", | |
| "print (state)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'\\n[[logit from 1st output],\\n[logit from 2nd output],\\n[logit from 3rd output],\\n[logit from 4th output]]\\n'" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "logits = tf.reshape(tf.concat(1, outputs), # shape = 1 x 16\n", | |
| " [-1, rnn_size]) # shape = 4 x 4\n", | |
| "logits.get_shape()\n", | |
| "\"\"\"\n", | |
| "[[logit from 1st output],\n", | |
| "[logit from 2nd output],\n", | |
| "[logit from 3rd output],\n", | |
| "[logit from 4th output]]\n", | |
| "\"\"\"\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "TensorShape([Dimension(4)])" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "targets = tf.reshape(ground_truth[1:], [-1]) # a shape of [-1] flattens into 1-D\n", | |
| "targets.get_shape()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "weights = tf.ones([len(char_dic) * batch_size])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "loss = tf.nn.seq2seq.sequence_loss_by_example([logits], [targets], [weights])\n", | |
| "cost = tf.reduce_sum(loss) / batch_size\n", | |
| "train_op = tf.train.RMSPropOptimizer(0.01, 0.9).minimize(cost)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Launch the graph in a session\n", | |
| "with tf.Session() as sess:\n", | |
| " tf.initialize_all_variables().run()\n", | |
| " for i in range(100):\n", | |
| " sess.run(train_op)\n", | |
| " result = sess.run(tf.argmax(logits, 1))\n", | |
| " print(result, [char_rdic[t] for t in result]) " | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "anaconda-cloud": {}, | |
| "kernelspec": { | |
| "display_name": "Python [tensorflow]", | |
| "language": "python", | |
| "name": "Python [tensorflow]" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment