Skip to content

Instantly share code, notes, and snippets.

@nzw0301
Last active June 3, 2016 17:37
Show Gist options
  • Save nzw0301/13dc8c51c513ab8738c01838af5bdddc to your computer and use it in GitHub Desktop.
Save nzw0301/13dc8c51c513ab8738c01838af5bdddc to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using Theano backend.\n"
]
}
],
"source": [
"import numpy as np\n",
"np.random.seed(13)\n",
"\n",
"from keras.models import Model\n",
"from keras.layers import Dense, Embedding, GRU, Input, merge\n",
"from keras.preprocessing.text import Tokenizer, base_filter\n",
"from keras.preprocessing import sequence\n",
"from keras.regularizers import l2\n",
"from keras.utils import np_utils\n",
"from keras.utils.visualize_util import model_to_dot, plot\n",
"from IPython.display import SVG\n",
"\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline\n",
"plt.style.use(\"ggplot\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Load the labelled tweets: each record is one line of tweet text, a TAB, then an integer label.\n",
"tweets = []\n",
"labels = []\n",
"# Explicit encoding: the default depends on the machine's locale.\n",
"with open(\"./data/tweets.tsv\", encoding=\"utf-8\") as f:\n",
"    for lineno, line in enumerate(f, 1):\n",
"        line = line.strip()\n",
"        if not line:\n",
"            continue  # tolerate blank lines such as a trailing newline at EOF\n",
"        # Split on the LAST tab so a tab inside the tweet text cannot break parsing.\n",
"        tweet, sep, label = line.rpartition(\"\\t\")\n",
"        if not sep:\n",
"            raise ValueError(\"tweets.tsv line %d has no tab-separated label\" % lineno)\n",
"        # Space-separate the characters so the word-level Tokenizer emits one token per character.\n",
"        tweets.append(\" \".join(tweet))\n",
"        labels.append(int(label))\n",
"if not tweets:\n",
"    raise ValueError(\"no examples found in ./data/tweets.tsv\")\n",
"\n",
"# Keras' validation_split holds out the LAST rows without shuffling them first.\n",
"# Shuffle once here (np.random is seeded in the first cell, so this is reproducible)\n",
"# in case the file is grouped by label -- NOTE(review): file order not verified.\n",
"order = np.random.permutation(len(tweets))\n",
"tweets = [tweets[i] for i in order]\n",
"labels = [labels[i] for i in order]\n",
"\n",
"maxlen = 140  # every sequence is padded/truncated to this many tokens\n",
"\n",
"# filters=\"\" disables the Tokenizer's punctuation stripping so symbols survive as tokens.\n",
"tokenizer = Tokenizer(filters=\"\")\n",
"tokenizer.fit_on_texts(tweets)\n",
"X_train = tokenizer.texts_to_sequences(tweets)\n",
"X_train = sequence.pad_sequences(X_train, maxlen=maxlen)\n",
"\n",
"# to_categorical needs class ids in 0..K-1, so map the raw labels onto that range\n",
"# (an identity mapping when the labels are already contiguous from 0).\n",
"classes = sorted(set(labels))\n",
"class_index = {c: i for i, c in enumerate(classes)}\n",
"Y_train = np_utils.to_categorical([class_index[c] for c in labels], len(classes))\n",
"\n",
"# Vocabulary size for the Embedding layer: token ids start at 1, id 0 is the padding value.\n",
"V = len(tokenizer.word_index) + 1"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<svg height=\"377pt\" viewBox=\"0.00 0.00 528.36 377.00\" width=\"528pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 373)\">\n",
"<title>G</title>\n",
"<polygon fill=\"white\" points=\"-4,4 -4,-373 524.362,-373 524.362,4 -4,4\" stroke=\"none\"/>\n",
"<!-- 4529127832 -->\n",
"<g class=\"node\" id=\"node1\"><title>4529127832</title>\n",
"<polygon fill=\"none\" points=\"123.124,-324.5 123.124,-368.5 396.238,-368.5 396.238,-324.5 123.124,-324.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"190.022\" y=\"-342.3\">input_1 (InputLayer)</text>\n",
"<polyline fill=\"none\" points=\"256.921,-324.5 256.921,-368.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"284.755\" y=\"-353.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"256.921,-346.5 312.59,-346.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"284.755\" y=\"-331.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"312.59,-324.5 312.59,-368.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"354.414\" y=\"-353.3\">(None, 140)</text>\n",
"<polyline fill=\"none\" points=\"312.59,-346.5 396.238,-346.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"354.414\" y=\"-331.3\">(None, 140)</text>\n",
"</g>\n",
"<!-- 4529128280 -->\n",
"<g class=\"node\" id=\"node2\"><title>4529128280</title>\n",
"<polygon fill=\"none\" points=\"94.7344,-243.5 94.7344,-287.5 424.628,-287.5 424.628,-243.5 94.7344,-243.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"179.522\" y=\"-261.3\">embedding_1 (Embedding)</text>\n",
"<polyline fill=\"none\" points=\"264.311,-243.5 264.311,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"292.145\" y=\"-272.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"264.311,-265.5 319.979,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"292.145\" y=\"-250.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"319.979,-243.5 319.979,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"372.304\" y=\"-272.3\">(None, 140)</text>\n",
"<polyline fill=\"none\" points=\"319.979,-265.5 424.628,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"372.304\" y=\"-250.3\">(None, 140, 80)</text>\n",
"</g>\n",
"<!-- 4529127832&#45;&gt;4529128280 -->\n",
"<g class=\"edge\" id=\"edge1\"><title>4529127832-&gt;4529128280</title>\n",
"<path d=\"M259.681,-324.329C259.681,-316.183 259.681,-306.699 259.681,-297.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"263.181,-297.729 259.681,-287.729 256.181,-297.729 263.181,-297.729\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4531654384 -->\n",
"<g class=\"node\" id=\"node3\"><title>4531654384</title>\n",
"<polygon fill=\"none\" points=\"0,-162.5 0,-206.5 251.362,-206.5 251.362,-162.5 0,-162.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"45.5225\" y=\"-180.3\">gru_1 (GRU)</text>\n",
"<polyline fill=\"none\" points=\"91.0449,-162.5 91.0449,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"118.879\" y=\"-191.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"91.0449,-184.5 146.714,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"118.879\" y=\"-169.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"146.714,-162.5 146.714,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"199.038\" y=\"-191.3\">(None, 140, 80)</text>\n",
"<polyline fill=\"none\" points=\"146.714,-184.5 251.362,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"199.038\" y=\"-169.3\">(None, 128)</text>\n",
"</g>\n",
"<!-- 4529128280&#45;&gt;4531654384 -->\n",
"<g class=\"edge\" id=\"edge2\"><title>4529128280-&gt;4531654384</title>\n",
"<path d=\"M223.772,-243.329C207.399,-233.677 187.845,-222.149 170.519,-211.934\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"172.082,-208.793 161.69,-206.729 168.527,-214.823 172.082,-208.793\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4528387184 -->\n",
"<g class=\"node\" id=\"node4\"><title>4528387184</title>\n",
"<polygon fill=\"none\" points=\"269,-162.5 269,-206.5 520.362,-206.5 520.362,-162.5 269,-162.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"314.522\" y=\"-180.3\">gru_2 (GRU)</text>\n",
"<polyline fill=\"none\" points=\"360.045,-162.5 360.045,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"387.879\" y=\"-191.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"360.045,-184.5 415.714,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"387.879\" y=\"-169.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"415.714,-162.5 415.714,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"468.038\" y=\"-191.3\">(None, 140, 80)</text>\n",
"<polyline fill=\"none\" points=\"415.714,-184.5 520.362,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"468.038\" y=\"-169.3\">(None, 128)</text>\n",
"</g>\n",
"<!-- 4529128280&#45;&gt;4528387184 -->\n",
"<g class=\"edge\" id=\"edge3\"><title>4529128280-&gt;4528387184</title>\n",
"<path d=\"M295.859,-243.329C312.353,-233.677 332.054,-222.149 349.509,-211.934\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"351.541,-214.8 358.404,-206.729 348.005,-208.759 351.541,-214.8\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4528387352 -->\n",
"<g class=\"node\" id=\"node5\"><title>4528387352</title>\n",
"<polygon fill=\"none\" points=\"91.1113,-81.5 91.1113,-125.5 428.251,-125.5 428.251,-81.5 91.1113,-81.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"148.036\" y=\"-99.3\">merge_1 (Merge)</text>\n",
"<polyline fill=\"none\" points=\"204.961,-81.5 204.961,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"232.795\" y=\"-110.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"204.961,-103.5 260.63,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"232.795\" y=\"-88.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"260.63,-81.5 260.63,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"344.44\" y=\"-110.3\">[(None, 128), (None, 128)]</text>\n",
"<polyline fill=\"none\" points=\"260.63,-103.5 428.251,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"343.954\" y=\"-88.3\">(None, 128)</text>\n",
"</g>\n",
"<!-- 4531654384&#45;&gt;4528387352 -->\n",
"<g class=\"edge\" id=\"edge4\"><title>4531654384-&gt;4528387352</title>\n",
"<path d=\"M161.591,-162.329C177.963,-152.677 197.518,-141.149 214.844,-130.934\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"216.836,-133.823 223.672,-125.729 213.281,-127.793 216.836,-133.823\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4528387184&#45;&gt;4528387352 -->\n",
"<g class=\"edge\" id=\"edge5\"><title>4528387184-&gt;4528387352</title>\n",
"<path d=\"M358.504,-162.329C342.009,-152.677 322.309,-141.149 304.853,-130.934\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"306.357,-127.759 295.959,-125.729 302.822,-133.8 306.357,-127.759\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4534657376 -->\n",
"<g class=\"node\" id=\"node6\"><title>4534657376</title>\n",
"<polygon fill=\"none\" points=\"135.179,-0.5 135.179,-44.5 384.183,-44.5 384.183,-0.5 135.179,-0.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"190.022\" y=\"-18.3\">dense_1 (Dense)</text>\n",
"<polyline fill=\"none\" points=\"244.866,-0.5 244.866,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"272.7\" y=\"-29.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"244.866,-22.5 300.535,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"272.7\" y=\"-7.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"300.535,-0.5 300.535,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"342.359\" y=\"-29.3\">(None, 128)</text>\n",
"<polyline fill=\"none\" points=\"300.535,-22.5 384.183,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"342.359\" y=\"-7.3\">(None, 30)</text>\n",
"</g>\n",
"<!-- 4528387352&#45;&gt;4534657376 -->\n",
"<g class=\"edge\" id=\"edge6\"><title>4528387352-&gt;4534657376</title>\n",
"<path d=\"M259.681,-81.3294C259.681,-73.1826 259.681,-63.6991 259.681,-54.7971\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"263.181,-54.729 259.681,-44.729 256.181,-54.729 263.181,-54.729\" stroke=\"black\"/>\n",
"</g>\n",
"</g>\n",
"</svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Character-level bidirectional GRU classifier: embed -> forward/backward GRU -> sum -> softmax.\n",
"l2_coef = 0.001\n",
"n_classes = len(set(labels))\n",
"\n",
"\n",
"def penalty():\n",
"    # Build a separate L2 regularizer object for every weight it is attached to.\n",
"    return l2(l=l2_coef)\n",
"\n",
"\n",
"def encode(inputs, **direction):\n",
"    # 128-unit GRU returning only its final state; extra kwargs choose the reading direction.\n",
"    layer = GRU(128, return_sequences=False, W_regularizer=penalty(), b_regularizer=penalty(), U_regularizer=penalty(), **direction)\n",
"    return layer(inputs)\n",
"\n",
"\n",
"char_ids = Input(shape=(maxlen,), dtype='int32')\n",
"char_vectors = Embedding(V, 80, input_length=maxlen, W_regularizer=penalty())(char_ids)\n",
"\n",
"# Read the character sequence in both directions and add the two final states.\n",
"forward_state = encode(char_vectors)\n",
"backward_state = encode(char_vectors, go_backwards=True)\n",
"tweet_vector = merge([forward_state, backward_state], mode=\"sum\")\n",
"\n",
"class_probs = Dense(n_classes, W_regularizer=penalty(), activation=\"softmax\")(tweet_vector)\n",
"\n",
"tweet2vec = Model(input=char_ids, output=class_probs)\n",
"tweet2vec.compile(\n",
"    loss='categorical_crossentropy',\n",
"    optimizer='RMSprop',\n",
"    metrics=['accuracy'],\n",
")\n",
"\n",
"# Render the layer graph (with tensor shapes) as the cell's output.\n",
"graph = model_to_dot(tweet2vec, show_shapes=True)\n",
"SVG(graph.create(prog='dot', format='svg'))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 337 samples, validate on 38 samples\n",
"Epoch 1/10\n",
"337/337 [==============================] - 6s - loss: 4.7598 - acc: 0.0742 - val_loss: 3.2707 - val_acc: 0.0000e+00\n",
"Epoch 2/10\n",
"337/337 [==============================] - 6s - loss: 4.4159 - acc: 0.1098 - val_loss: 3.2534 - val_acc: 0.1579\n",
"Epoch 3/10\n",
"337/337 [==============================] - 6s - loss: 4.1885 - acc: 0.1335 - val_loss: 3.2109 - val_acc: 0.1053\n",
"Epoch 4/10\n",
"337/337 [==============================] - 6s - loss: 3.9170 - acc: 0.1721 - val_loss: 3.2325 - val_acc: 0.0526\n",
"Epoch 5/10\n",
"337/337 [==============================] - 6s - loss: 3.7203 - acc: 0.1573 - val_loss: 3.5765 - val_acc: 0.0263\n",
"Epoch 6/10\n",
"337/337 [==============================] - 6s - loss: 3.5322 - acc: 0.1840 - val_loss: 3.6293 - val_acc: 0.0263\n",
"Epoch 7/10\n",
"337/337 [==============================] - 6s - loss: 3.4264 - acc: 0.1751 - val_loss: 3.1155 - val_acc: 0.1053\n",
"Epoch 8/10\n",
"337/337 [==============================] - 6s - loss: 3.2845 - acc: 0.1988 - val_loss: 3.1695 - val_acc: 0.1053\n",
"Epoch 9/10\n",
"337/337 [==============================] - 6s - loss: 3.2185 - acc: 0.1810 - val_loss: 3.2305 - val_acc: 0.0526\n",
"Epoch 10/10\n",
"337/337 [==============================] - 6s - loss: 3.1182 - acc: 0.2196 - val_loss: 3.2100 - val_acc: 0.0526\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x113231748>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train for 10 epochs in mini-batches of 32, holding out 10% of the rows for validation.\n",
"tweet2vec.fit(\n",
"    X_train,\n",
"    Y_train,\n",
"    nb_epoch=10,\n",
"    batch_size=32,\n",
"    validation_split=0.1,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## Reference: Tweet2Vec (Dhingra et al., 2016), http://arxiv.org/abs/1605.03481"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment