Last active
June 3, 2016 17:37
-
-
Save nzw0301/13dc8c51c513ab8738c01838af5bdddc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using Theano backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"np.random.seed(13)\n", | |
"\n", | |
"from keras.models import Model\n", | |
"from keras.layers import Dense, Embedding, GRU, Input, merge\n", | |
"from keras.preprocessing.text import Tokenizer, base_filter\n", | |
"from keras.preprocessing import sequence\n", | |
"from keras.regularizers import l2\n", | |
"from keras.utils import np_utils\n", | |
"from keras.utils.visualize_util import model_to_dot, plot\n", | |
"from IPython.display import SVG\n", | |
"\n", | |
"from matplotlib import pyplot as plt\n", | |
"%matplotlib inline\n", | |
"plt.style.use(\"ggplot\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"tweets = []\n", | |
"labels = []\n", | |
"with open(\"./data/tweets.tsv\") as f:\n", | |
" for l in f:\n", | |
" tweet, label = l.strip().split(\"\\t\")\n", | |
" tweets.append(\" \".join(list(tweet)))\n", | |
" labels.append(int(label))\n", | |
"maxlen = 140\n", | |
"\n", | |
"tokenizer = Tokenizer(filters=\"\")\n", | |
"tokenizer.fit_on_texts(tweets)\n", | |
"X_train = tokenizer.texts_to_sequences(tweets)\n", | |
"X_train = sequence.pad_sequences(X_train, maxlen=maxlen)\n", | |
"Y_train = np_utils.to_categorical(labels, len(set(labels)))\n", | |
"V = len(tokenizer.word_index) + 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<svg height=\"377pt\" viewBox=\"0.00 0.00 528.36 377.00\" width=\"528pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 373)\">\n", | |
"<title>G</title>\n", | |
"<polygon fill=\"white\" points=\"-4,4 -4,-373 524.362,-373 524.362,4 -4,4\" stroke=\"none\"/>\n", | |
"<!-- 4529127832 -->\n", | |
"<g class=\"node\" id=\"node1\"><title>4529127832</title>\n", | |
"<polygon fill=\"none\" points=\"123.124,-324.5 123.124,-368.5 396.238,-368.5 396.238,-324.5 123.124,-324.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"190.022\" y=\"-342.3\">input_1 (InputLayer)</text>\n", | |
"<polyline fill=\"none\" points=\"256.921,-324.5 256.921,-368.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"284.755\" y=\"-353.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"256.921,-346.5 312.59,-346.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"284.755\" y=\"-331.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"312.59,-324.5 312.59,-368.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"354.414\" y=\"-353.3\">(None, 140)</text>\n", | |
"<polyline fill=\"none\" points=\"312.59,-346.5 396.238,-346.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"354.414\" y=\"-331.3\">(None, 140)</text>\n", | |
"</g>\n", | |
"<!-- 4529128280 -->\n", | |
"<g class=\"node\" id=\"node2\"><title>4529128280</title>\n", | |
"<polygon fill=\"none\" points=\"94.7344,-243.5 94.7344,-287.5 424.628,-287.5 424.628,-243.5 94.7344,-243.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"179.522\" y=\"-261.3\">embedding_1 (Embedding)</text>\n", | |
"<polyline fill=\"none\" points=\"264.311,-243.5 264.311,-287.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"292.145\" y=\"-272.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"264.311,-265.5 319.979,-265.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"292.145\" y=\"-250.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"319.979,-243.5 319.979,-287.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"372.304\" y=\"-272.3\">(None, 140)</text>\n", | |
"<polyline fill=\"none\" points=\"319.979,-265.5 424.628,-265.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"372.304\" y=\"-250.3\">(None, 140, 80)</text>\n", | |
"</g>\n", | |
"<!-- 4529127832->4529128280 -->\n", | |
"<g class=\"edge\" id=\"edge1\"><title>4529127832->4529128280</title>\n", | |
"<path d=\"M259.681,-324.329C259.681,-316.183 259.681,-306.699 259.681,-297.797\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"263.181,-297.729 259.681,-287.729 256.181,-297.729 263.181,-297.729\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"<!-- 4531654384 -->\n", | |
"<g class=\"node\" id=\"node3\"><title>4531654384</title>\n", | |
"<polygon fill=\"none\" points=\"0,-162.5 0,-206.5 251.362,-206.5 251.362,-162.5 0,-162.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"45.5225\" y=\"-180.3\">gru_1 (GRU)</text>\n", | |
"<polyline fill=\"none\" points=\"91.0449,-162.5 91.0449,-206.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"118.879\" y=\"-191.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"91.0449,-184.5 146.714,-184.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"118.879\" y=\"-169.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"146.714,-162.5 146.714,-206.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"199.038\" y=\"-191.3\">(None, 140, 80)</text>\n", | |
"<polyline fill=\"none\" points=\"146.714,-184.5 251.362,-184.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"199.038\" y=\"-169.3\">(None, 128)</text>\n", | |
"</g>\n", | |
"<!-- 4529128280->4531654384 -->\n", | |
"<g class=\"edge\" id=\"edge2\"><title>4529128280->4531654384</title>\n", | |
"<path d=\"M223.772,-243.329C207.399,-233.677 187.845,-222.149 170.519,-211.934\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"172.082,-208.793 161.69,-206.729 168.527,-214.823 172.082,-208.793\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"<!-- 4528387184 -->\n", | |
"<g class=\"node\" id=\"node4\"><title>4528387184</title>\n", | |
"<polygon fill=\"none\" points=\"269,-162.5 269,-206.5 520.362,-206.5 520.362,-162.5 269,-162.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"314.522\" y=\"-180.3\">gru_2 (GRU)</text>\n", | |
"<polyline fill=\"none\" points=\"360.045,-162.5 360.045,-206.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"387.879\" y=\"-191.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"360.045,-184.5 415.714,-184.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"387.879\" y=\"-169.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"415.714,-162.5 415.714,-206.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"468.038\" y=\"-191.3\">(None, 140, 80)</text>\n", | |
"<polyline fill=\"none\" points=\"415.714,-184.5 520.362,-184.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"468.038\" y=\"-169.3\">(None, 128)</text>\n", | |
"</g>\n", | |
"<!-- 4529128280->4528387184 -->\n", | |
"<g class=\"edge\" id=\"edge3\"><title>4529128280->4528387184</title>\n", | |
"<path d=\"M295.859,-243.329C312.353,-233.677 332.054,-222.149 349.509,-211.934\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"351.541,-214.8 358.404,-206.729 348.005,-208.759 351.541,-214.8\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"<!-- 4528387352 -->\n", | |
"<g class=\"node\" id=\"node5\"><title>4528387352</title>\n", | |
"<polygon fill=\"none\" points=\"91.1113,-81.5 91.1113,-125.5 428.251,-125.5 428.251,-81.5 91.1113,-81.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"148.036\" y=\"-99.3\">merge_1 (Merge)</text>\n", | |
"<polyline fill=\"none\" points=\"204.961,-81.5 204.961,-125.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"232.795\" y=\"-110.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"204.961,-103.5 260.63,-103.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"232.795\" y=\"-88.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"260.63,-81.5 260.63,-125.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"344.44\" y=\"-110.3\">[(None, 128), (None, 128)]</text>\n", | |
"<polyline fill=\"none\" points=\"260.63,-103.5 428.251,-103.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"343.954\" y=\"-88.3\">(None, 128)</text>\n", | |
"</g>\n", | |
"<!-- 4531654384->4528387352 -->\n", | |
"<g class=\"edge\" id=\"edge4\"><title>4531654384->4528387352</title>\n", | |
"<path d=\"M161.591,-162.329C177.963,-152.677 197.518,-141.149 214.844,-130.934\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"216.836,-133.823 223.672,-125.729 213.281,-127.793 216.836,-133.823\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"<!-- 4528387184->4528387352 -->\n", | |
"<g class=\"edge\" id=\"edge5\"><title>4528387184->4528387352</title>\n", | |
"<path d=\"M358.504,-162.329C342.009,-152.677 322.309,-141.149 304.853,-130.934\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"306.357,-127.759 295.959,-125.729 302.822,-133.8 306.357,-127.759\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"<!-- 4534657376 -->\n", | |
"<g class=\"node\" id=\"node6\"><title>4534657376</title>\n", | |
"<polygon fill=\"none\" points=\"135.179,-0.5 135.179,-44.5 384.183,-44.5 384.183,-0.5 135.179,-0.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"190.022\" y=\"-18.3\">dense_1 (Dense)</text>\n", | |
"<polyline fill=\"none\" points=\"244.866,-0.5 244.866,-44.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"272.7\" y=\"-29.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"244.866,-22.5 300.535,-22.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"272.7\" y=\"-7.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"300.535,-0.5 300.535,-44.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"342.359\" y=\"-29.3\">(None, 128)</text>\n", | |
"<polyline fill=\"none\" points=\"300.535,-22.5 384.183,-22.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"342.359\" y=\"-7.3\">(None, 30)</text>\n", | |
"</g>\n", | |
"<!-- 4528387352->4534657376 -->\n", | |
"<g class=\"edge\" id=\"edge6\"><title>4528387352->4534657376</title>\n", | |
"<path d=\"M259.681,-81.3294C259.681,-73.1826 259.681,-63.6991 259.681,-54.7971\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"263.181,-54.729 259.681,-44.729 256.181,-54.729 263.181,-54.729\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.SVG object>" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"l2_coef = 0.001\n", | |
"tweet = Input(shape=(maxlen,), dtype='int32')\n", | |
"x = Embedding(V, 80, input_length=maxlen, W_regularizer=l2(l=l2_coef))(tweet)\n", | |
"f = GRU(128, return_sequences=False, W_regularizer=l2(l=l2_coef), b_regularizer=l2(l=l2_coef), U_regularizer=l2(l=l2_coef))(x)\n", | |
"b = GRU(128, go_backwards=True, return_sequences=False, W_regularizer=l2(l=l2_coef), b_regularizer=l2(l=l2_coef), U_regularizer=l2(l=l2_coef))(x)\n", | |
"x = merge([f, b], mode=\"sum\")\n", | |
"x = Dense(len(set(labels)), W_regularizer=l2(l=l2_coef), activation=\"softmax\")(x)\n", | |
"\n", | |
"tweet2vec = Model(input=tweet, output=x)\n", | |
"\n", | |
"tweet2vec.compile(loss='categorical_crossentropy',\n", | |
" optimizer='RMSprop',\n", | |
" metrics=['accuracy'])\n", | |
"\n", | |
"SVG(model_to_dot(tweet2vec, show_shapes=True).create(prog='dot', format='svg'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train on 337 samples, validate on 38 samples\n", | |
"Epoch 1/10\n", | |
"337/337 [==============================] - 6s - loss: 4.7598 - acc: 0.0742 - val_loss: 3.2707 - val_acc: 0.0000e+00\n", | |
"Epoch 2/10\n", | |
"337/337 [==============================] - 6s - loss: 4.4159 - acc: 0.1098 - val_loss: 3.2534 - val_acc: 0.1579\n", | |
"Epoch 3/10\n", | |
"337/337 [==============================] - 6s - loss: 4.1885 - acc: 0.1335 - val_loss: 3.2109 - val_acc: 0.1053\n", | |
"Epoch 4/10\n", | |
"337/337 [==============================] - 6s - loss: 3.9170 - acc: 0.1721 - val_loss: 3.2325 - val_acc: 0.0526\n", | |
"Epoch 5/10\n", | |
"337/337 [==============================] - 6s - loss: 3.7203 - acc: 0.1573 - val_loss: 3.5765 - val_acc: 0.0263\n", | |
"Epoch 6/10\n", | |
"337/337 [==============================] - 6s - loss: 3.5322 - acc: 0.1840 - val_loss: 3.6293 - val_acc: 0.0263\n", | |
"Epoch 7/10\n", | |
"337/337 [==============================] - 6s - loss: 3.4264 - acc: 0.1751 - val_loss: 3.1155 - val_acc: 0.1053\n", | |
"Epoch 8/10\n", | |
"337/337 [==============================] - 6s - loss: 3.2845 - acc: 0.1988 - val_loss: 3.1695 - val_acc: 0.1053\n", | |
"Epoch 9/10\n", | |
"337/337 [==============================] - 6s - loss: 3.2185 - acc: 0.1810 - val_loss: 3.2305 - val_acc: 0.0526\n", | |
"Epoch 10/10\n", | |
"337/337 [==============================] - 6s - loss: 3.1182 - acc: 0.2196 - val_loss: 3.2100 - val_acc: 0.0526\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<keras.callbacks.History at 0x113231748>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tweet2vec.fit(X_train, Y_train, nb_epoch=10, batch_size=32, validation_split=0.1)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"collapsed": true | |
}, | |
"source": [ | |
"## Reference: Tweet2Vec (Dhingra et al., 2016), http://arxiv.org/abs/1605.03481" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment