Skip to content

Instantly share code, notes, and snippets.

@nzw0301
Last active May 3, 2016 12:19
Show Gist options
  • Save nzw0301/eea37f82f03a5f85dafc4314d395b2b3 to your computer and use it in GitHub Desktop.
Save nzw0301/eea37f82f03a5f85dafc4314d395b2b3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using Theano backend.\n"
]
}
],
"source": [
"from keras.models import Sequential, model_from_json\n",
"from keras.layers import Dense, Activation, Embedding, GRU, Merge, RepeatVector, TimeDistributed\n",
"from keras.optimizers import Adadelta\n",
"from keras.utils import np_utils\n",
"from keras.utils.visualize_util import model_to_dot, plot\n",
"from keras.preprocessing import sequence\n",
"from keras.preprocessing.text import Tokenizer, base_filter\n",
"import numpy as np\n",
"import json\n",
"from IPython.display import SVG, display\n",
"\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline\n",
"plt.style.use(\"ggplot\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Seed NumPy's global RNG so that re-running the notebook draws the same numbers.\n",
"np.random.seed(13)\n",
"\n",
"# Characters stripped by the tokenizers: Keras' base_filter() set plus extra punctuation.\n",
"extra_punctuation = \"「」・。、()?! '\"\n",
"filters = base_filter() + extra_punctuation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Source side: Japanese sentences -> lists of word ids\n",
"min_count = 1  # words occurring fewer than min_count times are dropped\n",
"\n",
"# The context manager closes the handle; the encoding is explicit because the corpus\n",
"# is Japanese text and the platform default may not decode it.\n",
"# NOTE(review): assumes data/ja.txt is UTF-8 -- confirm.\n",
"with open(\"data/ja.txt\", encoding=\"utf-8\") as ja_file:\n",
"    ja_docs = ja_file.readlines()[110889:110909]\n",
"\n",
"# First pass: collect word counts over the selected sentences.\n",
"ja_tokenizer = Tokenizer(filters=filters)\n",
"ja_tokenizer.fit_on_texts(ja_docs)\n",
"\n",
"# Tokenizer only emits word indices strictly below nb_words, so add 1 to keep every\n",
"# word that reaches min_count (without it the rarest kept word is silently dropped).\n",
"nb_rare = sum(1 for count in ja_tokenizer.word_counts.values() if count < min_count)\n",
"nb_words = len(ja_tokenizer.word_index) - nb_rare + 1\n",
"\n",
"# Second pass: same texts, now with the vocabulary cap applied.\n",
"ja_tokenizer = Tokenizer(filters=filters, nb_words=nb_words)\n",
"ja_tokenizer.fit_on_texts(ja_docs)\n",
"\n",
"ja_docs = ja_tokenizer.texts_to_sequences(ja_docs)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Target side: English sentences -> lists of word ids\n",
"# The context manager closes the handle; the encoding is made explicit.\n",
"# NOTE(review): assumes data/en.txt is UTF-8 -- confirm.\n",
"with open(\"data/en.txt\", encoding=\"utf-8\") as en_file:\n",
"    en_docs = en_file.readlines()[110889:110909]\n",
"\n",
"# Wrap every sentence in start / end markers (referred to as \"gos\" / \"eos\" when decoding).\n",
"en_docs = [\"GOS \" + line + \" EOS\" for line in en_docs]\n",
"\n",
"# First pass: collect word counts over the selected sentences.\n",
"en_tokenizer = Tokenizer(filters=filters)\n",
"en_tokenizer.fit_on_texts(en_docs)\n",
"\n",
"# Tokenizer only emits word indices strictly below nb_words, so add 1 to keep every\n",
"# word that reaches min_count (without it the rarest kept word is silently dropped).\n",
"nb_rare = sum(1 for count in en_tokenizer.word_counts.values() if count < min_count)\n",
"nb_words = len(en_tokenizer.word_index) - nb_rare + 1\n",
"\n",
"# Second pass: same texts, now with the vocabulary cap applied.\n",
"en_tokenizer = Tokenizer(filters=filters, nb_words=nb_words)\n",
"en_tokenizer.fit_on_texts(en_docs)\n",
"\n",
"en_docs = en_tokenizer.texts_to_sequences(en_docs)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Longest source sentence, and longest decoder prefix (a target sentence minus its last token).\n",
"encoder_maxlen = max(len(doc) for doc in ja_docs)\n",
"decoder_maxlen = max(len(doc) - 1 for doc in en_docs)\n",
"# +1 leaves room for index 0: Keras word indices start at 1 and 0 is the padding value.\n",
"encoder_vocab_size = len(ja_tokenizer.word_index) + 1\n",
"decoder_vocab_size = len(en_tokenizer.word_index) + 1\n",
"\n",
"\n",
"def gen_training_data(X, Y, encoder_maxlen, decoder_maxlen, V, samples_size):\n",
"    \"\"\"Yield mini-batches of next-word prediction samples.\n",
"\n",
"    Each (source, target) sentence pair is expanded into one sample per target\n",
"    position j >= 1: encoder input = the source sentence, decoder input = target[:j],\n",
"    label = target[j].\n",
"\n",
"    X, Y: parallel lists of word-id sequences (source / target).\n",
"    encoder_maxlen, decoder_maxlen: lengths the two inputs are padded to.\n",
"    V: number of classes used for the one-hot labels.\n",
"    samples_size: samples per yielded batch; the final batch may be smaller.\n",
"\n",
"    Yields ([encoder_inputs, decoder_inputs], labels).\n",
"    \"\"\"\n",
"    def as_batch(enc_seqs, dec_seqs, targets):\n",
"        # Pad both inputs to fixed length and one-hot encode the next-word labels.\n",
"        labels = np_utils.to_categorical(targets, V)\n",
"        enc = sequence.pad_sequences(enc_seqs, maxlen=encoder_maxlen)\n",
"        dec = sequence.pad_sequences(dec_seqs, maxlen=decoder_maxlen)\n",
"        return ([enc, dec], labels)\n",
"\n",
"    encoder_inputs, decoder_inputs, next_words = [], [], []\n",
"    for x_doc, y_doc in zip(X, Y):\n",
"        for j in range(1, len(y_doc)):\n",
"            encoder_inputs.append(x_doc)\n",
"            decoder_inputs.append(y_doc[:j])\n",
"            next_words.append(y_doc[j])\n",
"            if len(next_words) == samples_size:\n",
"                yield as_batch(encoder_inputs, decoder_inputs, next_words)\n",
"                encoder_inputs, decoder_inputs, next_words = [], [], []\n",
"\n",
"    # Flush the last partial batch. Skip it when empty (sample count an exact multiple\n",
"    # of samples_size) so callers never receive a zero-sample batch.\n",
"    if next_words:\n",
"        yield as_batch(encoder_inputs, decoder_inputs, next_words)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"hidden_units = 128  # embedding size and GRU / Dense width used by every sub-network\n",
"\n",
"# Encoder: embed the source sentence, keep only the final GRU state and repeat\n",
"# that state once per decoder time step.\n",
"encoder = Sequential()\n",
"for layer in (\n",
"    Embedding(encoder_vocab_size, hidden_units, input_length=encoder_maxlen),\n",
"    GRU(hidden_units, return_sequences=False),\n",
"    RepeatVector(decoder_maxlen),\n",
"):\n",
"    encoder.add(layer)\n",
"\n",
"# Decoder input branch: embed the target prefix and emit one vector per time step.\n",
"decoder_input = Sequential()\n",
"for layer in (\n",
"    Embedding(decoder_vocab_size, hidden_units, input_length=decoder_maxlen),\n",
"    GRU(output_dim=hidden_units, return_sequences=True),\n",
"    TimeDistributed(Dense(hidden_units)),\n",
"):\n",
"    decoder_input.add(layer)\n",
"\n",
"# Concatenate both branches along the feature axis and predict the next target word.\n",
"model = Sequential()\n",
"for layer in (\n",
"    Merge([encoder, decoder_input], mode='concat', concat_axis=-1),\n",
"    GRU(hidden_units, return_sequences=False),\n",
"    Dense(decoder_vocab_size),\n",
"    Activation('softmax'),\n",
"):\n",
"    model.add(layer)\n",
"\n",
"model.compile(loss='categorical_crossentropy', optimizer='adadelta')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoder\n"
]
},
{
"data": {
"image/svg+xml": [
"<svg height=\"296pt\" viewBox=\"0.00 0.00 355.72 296.00\" width=\"356pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 292)\">\n",
"<title>G</title>\n",
"<polygon fill=\"white\" points=\"-4,4 -4,-292 351.722,-292 351.722,4 -4,4\" stroke=\"none\"/>\n",
"<!-- 4591283112 -->\n",
"<g class=\"node\" id=\"node1\"><title>4591283112</title>\n",
"<polygon fill=\"none\" points=\"6.2002,-243.5 6.2002,-287.5 341.521,-287.5 341.521,-243.5 6.2002,-243.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"107.702\" y=\"-261.3\">embedding_input_1 (InputLayer)</text>\n",
"<polyline fill=\"none\" points=\"209.204,-243.5 209.204,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"237.039\" y=\"-272.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"209.204,-265.5 264.873,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"237.039\" y=\"-250.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"264.873,-243.5 264.873,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"303.197\" y=\"-272.3\">(None, 56)</text>\n",
"<polyline fill=\"none\" points=\"264.873,-265.5 341.521,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"303.197\" y=\"-250.3\">(None, 56)</text>\n",
"</g>\n",
"<!-- 4591282944 -->\n",
"<g class=\"node\" id=\"node2\"><title>4591282944</title>\n",
"<polygon fill=\"none\" points=\"8.91406,-162.5 8.91406,-206.5 338.808,-206.5 338.808,-162.5 8.91406,-162.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"93.7021\" y=\"-180.3\">embedding_1 (Embedding)</text>\n",
"<polyline fill=\"none\" points=\"178.49,-162.5 178.49,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"206.325\" y=\"-191.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"178.49,-184.5 234.159,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"206.325\" y=\"-169.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"234.159,-162.5 234.159,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"286.483\" y=\"-191.3\">(None, 56)</text>\n",
"<polyline fill=\"none\" points=\"234.159,-184.5 338.808,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"286.483\" y=\"-169.3\">(None, 56, 128)</text>\n",
"</g>\n",
"<!-- 4591283112&#45;&gt;4591282944 -->\n",
"<g class=\"edge\" id=\"edge1\"><title>4591283112-&gt;4591282944</title>\n",
"<path d=\"M173.861,-243.329C173.861,-235.183 173.861,-225.699 173.861,-216.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"177.361,-216.729 173.861,-206.729 170.361,-216.729 177.361,-216.729\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4591283000 -->\n",
"<g class=\"node\" id=\"node3\"><title>4591283000</title>\n",
"<polygon fill=\"none\" points=\"48.1797,-81.5 48.1797,-125.5 299.542,-125.5 299.542,-81.5 48.1797,-81.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"93.7021\" y=\"-99.3\">gru_1 (GRU)</text>\n",
"<polyline fill=\"none\" points=\"139.225,-81.5 139.225,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"167.059\" y=\"-110.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"139.225,-103.5 194.894,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"167.059\" y=\"-88.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"194.894,-81.5 194.894,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"247.218\" y=\"-110.3\">(None, 56, 128)</text>\n",
"<polyline fill=\"none\" points=\"194.894,-103.5 299.542,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"247.218\" y=\"-88.3\">(None, 128)</text>\n",
"</g>\n",
"<!-- 4591282944&#45;&gt;4591283000 -->\n",
"<g class=\"edge\" id=\"edge2\"><title>4591282944-&gt;4591283000</title>\n",
"<path d=\"M173.861,-162.329C173.861,-154.183 173.861,-144.699 173.861,-135.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"177.361,-135.729 173.861,-125.729 170.361,-135.729 177.361,-135.729\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4613172528 -->\n",
"<g class=\"node\" id=\"node4\"><title>4613172528</title>\n",
"<polygon fill=\"none\" points=\"0,-0.5 0,-44.5 347.722,-44.5 347.722,-0.5 0,-0.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"93.7021\" y=\"-18.3\">repeatvector_1 (RepeatVector)</text>\n",
"<polyline fill=\"none\" points=\"187.404,-0.5 187.404,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"215.239\" y=\"-29.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"187.404,-22.5 243.073,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"215.239\" y=\"-7.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"243.073,-0.5 243.073,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"295.397\" y=\"-29.3\">(None, 128)</text>\n",
"<polyline fill=\"none\" points=\"243.073,-22.5 347.722,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"295.397\" y=\"-7.3\">(None, 33, 128)</text>\n",
"</g>\n",
"<!-- 4591283000&#45;&gt;4613172528 -->\n",
"<g class=\"edge\" id=\"edge3\"><title>4591283000-&gt;4613172528</title>\n",
"<path d=\"M173.861,-81.3294C173.861,-73.1826 173.861,-63.6991 173.861,-54.7971\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"177.361,-54.729 173.861,-44.729 170.361,-54.729 177.361,-54.729\" stroke=\"black\"/>\n",
"</g>\n",
"</g>\n",
"</svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"decoder input\n"
]
},
{
"data": {
"image/svg+xml": [
"<svg height=\"296pt\" viewBox=\"0.00 0.00 387.96 296.00\" width=\"388pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 292)\">\n",
"<title>G</title>\n",
"<polygon fill=\"white\" points=\"-4,4 -4,-292 383.96,-292 383.96,4 -4,4\" stroke=\"none\"/>\n",
"<!-- 4464668856 -->\n",
"<g class=\"node\" id=\"node1\"><title>4464668856</title>\n",
"<polygon fill=\"none\" points=\"22.3193,-243.5 22.3193,-287.5 357.641,-287.5 357.641,-243.5 22.3193,-243.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"123.821\" y=\"-261.3\">embedding_input_2 (InputLayer)</text>\n",
"<polyline fill=\"none\" points=\"225.323,-243.5 225.323,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"253.158\" y=\"-272.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"225.323,-265.5 280.992,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"253.158\" y=\"-250.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"280.992,-243.5 280.992,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"319.316\" y=\"-272.3\">(None, 33)</text>\n",
"<polyline fill=\"none\" points=\"280.992,-265.5 357.641,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"319.316\" y=\"-250.3\">(None, 33)</text>\n",
"</g>\n",
"<!-- 4464668800 -->\n",
"<g class=\"node\" id=\"node2\"><title>4464668800</title>\n",
"<polygon fill=\"none\" points=\"25.0332,-162.5 25.0332,-206.5 354.927,-206.5 354.927,-162.5 25.0332,-162.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"109.821\" y=\"-180.3\">embedding_2 (Embedding)</text>\n",
"<polyline fill=\"none\" points=\"194.609,-162.5 194.609,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"222.444\" y=\"-191.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"194.609,-184.5 250.278,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"222.444\" y=\"-169.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"250.278,-162.5 250.278,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"302.603\" y=\"-191.3\">(None, 33)</text>\n",
"<polyline fill=\"none\" points=\"250.278,-184.5 354.927,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"302.603\" y=\"-169.3\">(None, 33, 128)</text>\n",
"</g>\n",
"<!-- 4464668856&#45;&gt;4464668800 -->\n",
"<g class=\"edge\" id=\"edge1\"><title>4464668856-&gt;4464668800</title>\n",
"<path d=\"M189.98,-243.329C189.98,-235.183 189.98,-225.699 189.98,-216.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"193.48,-216.729 189.98,-206.729 186.48,-216.729 193.48,-216.729\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4595134984 -->\n",
"<g class=\"node\" id=\"node3\"><title>4595134984</title>\n",
"<polygon fill=\"none\" points=\"64.2988,-81.5 64.2988,-125.5 315.661,-125.5 315.661,-81.5 64.2988,-81.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"109.821\" y=\"-99.3\">gru_2 (GRU)</text>\n",
"<polyline fill=\"none\" points=\"155.344,-81.5 155.344,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"183.178\" y=\"-110.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"155.344,-103.5 211.013,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"183.178\" y=\"-88.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"211.013,-81.5 211.013,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"263.337\" y=\"-110.3\">(None, 33, 128)</text>\n",
"<polyline fill=\"none\" points=\"211.013,-103.5 315.661,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"263.337\" y=\"-88.3\">(None, 33, 128)</text>\n",
"</g>\n",
"<!-- 4464668800&#45;&gt;4595134984 -->\n",
"<g class=\"edge\" id=\"edge2\"><title>4464668800-&gt;4595134984</title>\n",
"<path d=\"M189.98,-162.329C189.98,-154.183 189.98,-144.699 189.98,-135.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"193.48,-135.729 189.98,-125.729 186.48,-135.729 193.48,-135.729\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4464240680 -->\n",
"<g class=\"node\" id=\"node4\"><title>4464240680</title>\n",
"<polygon fill=\"none\" points=\"0,-0.5 0,-44.5 379.96,-44.5 379.96,-0.5 0,-0.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"109.821\" y=\"-18.3\">timedistributed_1 (TimeDistributed)</text>\n",
"<polyline fill=\"none\" points=\"219.643,-0.5 219.643,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"247.477\" y=\"-29.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"219.643,-22.5 275.312,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"247.477\" y=\"-7.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"275.312,-0.5 275.312,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"327.636\" y=\"-29.3\">(None, 33, 128)</text>\n",
"<polyline fill=\"none\" points=\"275.312,-22.5 379.96,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"327.636\" y=\"-7.3\">(None, 33, 128)</text>\n",
"</g>\n",
"<!-- 4595134984&#45;&gt;4464240680 -->\n",
"<g class=\"edge\" id=\"edge3\"><title>4595134984-&gt;4464240680</title>\n",
"<path d=\"M189.98,-81.3294C189.98,-73.1826 189.98,-63.6991 189.98,-54.7971\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"193.48,-54.729 189.98,-44.729 186.48,-54.729 193.48,-54.729\" stroke=\"black\"/>\n",
"</g>\n",
"</g>\n",
"</svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"merge and decoder output\n"
]
},
{
"data": {
"image/svg+xml": [
"<svg height=\"377pt\" viewBox=\"0.00 0.00 664.00 377.00\" width=\"664pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 373)\">\n",
"<title>G</title>\n",
"<polygon fill=\"white\" points=\"-4,4 -4,-373 659.997,-373 659.997,4 -4,4\" stroke=\"none\"/>\n",
"<!-- 4591282664 -->\n",
"<g class=\"node\" id=\"node1\"><title>4591282664</title>\n",
"<polygon fill=\"none\" points=\"0,-324.5 0,-368.5 318.997,-368.5 318.997,-324.5 0,-324.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"79.3398\" y=\"-342.3\">sequential_1 (Sequential)</text>\n",
"<polyline fill=\"none\" points=\"158.68,-324.5 158.68,-368.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"186.514\" y=\"-353.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"158.68,-346.5 214.349,-346.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"186.514\" y=\"-331.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"214.349,-324.5 214.349,-368.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"266.673\" y=\"-353.3\">(None, 56)</text>\n",
"<polyline fill=\"none\" points=\"214.349,-346.5 318.997,-346.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"266.673\" y=\"-331.3\">(None, 33, 128)</text>\n",
"</g>\n",
"<!-- 4464264640 -->\n",
"<g class=\"node\" id=\"node3\"><title>4464264640</title>\n",
"<polygon fill=\"none\" points=\"137.929,-243.5 137.929,-287.5 517.068,-287.5 517.068,-243.5 137.929,-243.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"194.854\" y=\"-261.3\">merge_1 (Merge)</text>\n",
"<polyline fill=\"none\" points=\"251.778,-243.5 251.778,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"279.613\" y=\"-272.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"251.778,-265.5 307.447,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"279.613\" y=\"-250.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"307.447,-243.5 307.447,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"412.258\" y=\"-272.3\">[(None, 33, 128), (None, 33, 128)]</text>\n",
"<polyline fill=\"none\" points=\"307.447,-265.5 517.068,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"411.771\" y=\"-250.3\">(None, 33, 256)</text>\n",
"</g>\n",
"<!-- 4591282664&#45;&gt;4464264640 -->\n",
"<g class=\"edge\" id=\"edge1\"><title>4591282664-&gt;4464264640</title>\n",
"<path d=\"M204.519,-324.329C225.696,-314.372 251.119,-302.417 273.339,-291.968\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"275.085,-295.015 282.645,-287.592 272.106,-288.68 275.085,-295.015\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4595134928 -->\n",
"<g class=\"node\" id=\"node2\"><title>4595134928</title>\n",
"<polygon fill=\"none\" points=\"337,-324.5 337,-368.5 655.997,-368.5 655.997,-324.5 337,-324.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"416.34\" y=\"-342.3\">sequential_2 (Sequential)</text>\n",
"<polyline fill=\"none\" points=\"495.68,-324.5 495.68,-368.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"523.514\" y=\"-353.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"495.68,-346.5 551.349,-346.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"523.514\" y=\"-331.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"551.349,-324.5 551.349,-368.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"603.673\" y=\"-353.3\">(None, 33)</text>\n",
"<polyline fill=\"none\" points=\"551.349,-346.5 655.997,-346.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"603.673\" y=\"-331.3\">(None, 33, 128)</text>\n",
"</g>\n",
"<!-- 4595134928&#45;&gt;4464264640 -->\n",
"<g class=\"edge\" id=\"edge2\"><title>4595134928-&gt;4464264640</title>\n",
"<path d=\"M451.21,-324.329C429.812,-314.327 404.105,-302.31 381.681,-291.828\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"383.16,-288.656 372.619,-287.592 380.196,-294.997 383.16,-288.656\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4464264696 -->\n",
"<g class=\"node\" id=\"node4\"><title>4464264696</title>\n",
"<polygon fill=\"none\" points=\"201.817,-162.5 201.817,-206.5 453.18,-206.5 453.18,-162.5 201.817,-162.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"247.34\" y=\"-180.3\">gru_3 (GRU)</text>\n",
"<polyline fill=\"none\" points=\"292.862,-162.5 292.862,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"320.697\" y=\"-191.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"292.862,-184.5 348.531,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"320.697\" y=\"-169.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"348.531,-162.5 348.531,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"400.855\" y=\"-191.3\">(None, 33, 256)</text>\n",
"<polyline fill=\"none\" points=\"348.531,-184.5 453.18,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"400.855\" y=\"-169.3\">(None, 128)</text>\n",
"</g>\n",
"<!-- 4464264640&#45;&gt;4464264696 -->\n",
"<g class=\"edge\" id=\"edge3\"><title>4464264640-&gt;4464264696</title>\n",
"<path d=\"M327.499,-243.329C327.499,-235.183 327.499,-225.699 327.499,-216.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"330.999,-216.729 327.499,-206.729 323.999,-216.729 330.999,-216.729\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4464265312 -->\n",
"<g class=\"node\" id=\"node5\"><title>4464265312</title>\n",
"<polygon fill=\"none\" points=\"202.997,-81.5 202.997,-125.5 452,-125.5 452,-81.5 202.997,-81.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"257.84\" y=\"-99.3\">dense_2 (Dense)</text>\n",
"<polyline fill=\"none\" points=\"312.683,-81.5 312.683,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"340.518\" y=\"-110.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"312.683,-103.5 368.352,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"340.518\" y=\"-88.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"368.352,-81.5 368.352,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"410.176\" y=\"-110.3\">(None, 128)</text>\n",
"<polyline fill=\"none\" points=\"368.352,-103.5 452,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"410.176\" y=\"-88.3\">(None, 132)</text>\n",
"</g>\n",
"<!-- 4464264696&#45;&gt;4464265312 -->\n",
"<g class=\"edge\" id=\"edge4\"><title>4464264696-&gt;4464265312</title>\n",
"<path d=\"M327.499,-162.329C327.499,-154.183 327.499,-144.699 327.499,-135.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"330.999,-135.729 327.499,-125.729 323.999,-135.729 330.999,-135.729\" stroke=\"black\"/>\n",
"</g>\n",
"<!-- 4595023824 -->\n",
"<g class=\"node\" id=\"node6\"><title>4595023824</title>\n",
"<polygon fill=\"none\" points=\"179.279,-0.5 179.279,-44.5 475.718,-44.5 475.718,-0.5 179.279,-0.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"257.84\" y=\"-18.3\">activation_1 (Activation)</text>\n",
"<polyline fill=\"none\" points=\"336.4,-0.5 336.4,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"364.235\" y=\"-29.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"336.4,-22.5 392.069,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"364.235\" y=\"-7.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"392.069,-0.5 392.069,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"433.894\" y=\"-29.3\">(None, 132)</text>\n",
"<polyline fill=\"none\" points=\"392.069,-22.5 475.718,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"433.894\" y=\"-7.3\">(None, 132)</text>\n",
"</g>\n",
"<!-- 4464265312&#45;&gt;4595023824 -->\n",
"<g class=\"edge\" id=\"edge5\"><title>4464265312-&gt;4595023824</title>\n",
"<path d=\"M327.499,-81.3294C327.499,-73.1826 327.499,-63.6991 327.499,-54.7971\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"330.999,-54.729 327.499,-44.729 323.999,-54.729 330.999,-54.729\" stroke=\"black\"/>\n",
"</g>\n",
"</g>\n",
"</svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Render the layer graph (with tensor shapes) of each sub-network as inline SVG.\n",
"for title, network in (\n",
"    (\"encoder\", encoder),\n",
"    (\"decoder input\", decoder_input),\n",
"    (\"merge and decoder output\", model),\n",
"):\n",
"    print(title)\n",
"    graph = model_to_dot(network, show_shapes=True)\n",
"    display(SVG(graph.create(prog='dot', format='svg')))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 loss 4.88168478012\n",
"[WARNING] my_model_weights_0.h5 already exists - overwrite? [y/n]y\n",
"[TIP] Next time specify overwrite=True in save_weights!\n",
"1 loss 4.85395431519\n",
"[WARNING] my_model_weights_1.h5 already exists - overwrite? [y/n]y\n",
"[TIP] Next time specify overwrite=True in save_weights!\n",
"2 loss 4.82331466675\n",
"3 loss 4.78387212753\n",
"4 loss 4.72859144211\n",
"5 loss 4.65308618546\n",
"6 loss 4.57625722885\n",
"7 loss 4.53256464005\n",
"8 loss 4.50397205353\n",
"9 loss 4.48307847977\n"
]
}
],
"source": [
"# training\n",
"for i in range(10):\n",
"    # Sum of the per-batch losses over one pass through the training samples.\n",
"    loss = 0.\n",
"    for x, y in gen_training_data(ja_docs, en_docs, encoder_maxlen, decoder_maxlen, decoder_vocab_size, 256):\n",
"        loss += model.train_on_batch(x, y)\n",
"    print(i, \"loss \", loss)\n",
"    # One checkpoint per epoch. overwrite=True keeps a re-run of this cell from\n",
"    # blocking on Keras' interactive \"already exists - overwrite? [y/n]\" prompt.\n",
"    model.save_weights('my_model_weights_' + str(i) + '.h5', overwrite=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Persist the architecture only; the weights live in the per-epoch .h5 files.\n",
"# model.to_json() already returns a JSON string, so json.dump() stores it as a JSON\n",
"# string literal; the loading cell reverses that with json.load().\n",
"with open('my_model.json', 'w') as model_file:\n",
"    serialized = model.to_json()\n",
"    json.dump(serialized, model_file)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Read back the architecture description written by the saving cell.\n",
"with open(\"my_model.json\") as model_file:\n",
"    model = json.load(model_file)\n",
"\n",
"# Rebuild the network from that description ...\n",
"model = model_from_json(model)\n",
"\n",
"# ... restore the weights saved for epoch index 9, then compile with the same\n",
"# loss and optimizer as before.\n",
"model.load_weights(\"./my_model_weights_9.h5\")\n",
"model.compile(loss='categorical_crossentropy', optimizer='adadelta')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def sample(p):\n",
"    \"\"\"Draw one index at random from the categorical distribution p.\n",
"\n",
"    p: 1-D sequence of non-negative weights; it is renormalised here, so it does\n",
"       not have to sum to exactly 1.\n",
"    Returns the drawn index as an int. The caller's array is left unmodified.\n",
"    \"\"\"\n",
"    # Work on a float64 copy: dividing in place altered the caller's array, and a\n",
"    # float32 vector can still sum to slightly more than 1 after normalisation,\n",
"    # which np.random.multinomial rejects.\n",
"    probs = np.array(p, dtype=np.float64)\n",
"    probs /= probs.sum()\n",
"    # A single multinomial trial is a one-hot vector; argmax is the drawn index.\n",
"    return int(np.argmax(np.random.multinomial(1, probs)))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shall over turning before broke lane she i lane very came kissed he until door away the to should that \n",
"you that candle to lifted long broke i think ever comfortably night think good more word as lifted silence deserves ask the but then not talk he norbury was me if i towards \n",
"to you given broke one than we very getting say should know obliged grant another one \n",
"comfortably think he than you ear you you it lifted and and better think not door which am we out home one now what kissed kindly one me broke one that his ten use my kindly wife sleeve us one for he kindly still in obliged followed the \n",
"turning i came talk shall very silence and kissed my more hands the said to when broke it am me man \n",
"towards it i child should lane broke and what that to comfortably silence a a what turned them powers which an be another what came if her in become effie she kindly and man over you before you said when a \n",
"powers said and down down us i am overconfident came in candle effie an think think deserves she as am clasped sleeve turning holmes his not with then more getting came getting clasped other came little broke when a candle lane being think the towards think more you other answer i which a which answer\n"
]
}
],
"source": [
"encoder_words = \"そして グラント マン ロー は 沈黙 を 破っ た\"\n",
"encoder_in = sequence.pad_sequences(ja_tokenizer.texts_to_sequences([encoder_words]), maxlen=encoder_maxlen)\n",
"np.random.seed(13)\n",
"\n",
"# Reverse vocabulary, built once instead of scanning word_index for every sampled id.\n",
"index_to_word = {index: word for word, index in en_tokenizer.word_index.items()}\n",
"\n",
"res = set()\n",
"for i in range(100):\n",
"    decoder_tokens = [\"gos\"]\n",
"    # At most decoder_maxlen steps: a longer prefix would be truncated from the front\n",
"    # by pad_sequences and lose the leading \"gos\" marker used during training.\n",
"    for _ in range(decoder_maxlen):\n",
"        decoder_words = \" \".join(decoder_tokens)\n",
"        decoder_in = sequence.pad_sequences(en_tokenizer.texts_to_sequences([decoder_words]), maxlen=decoder_maxlen)\n",
"        words_p = model.predict([encoder_in, decoder_in])[0]\n",
"\n",
"        word = index_to_word.get(sample(words_p))\n",
"        if word is None:\n",
"            # The sampled id has no word (e.g. the padding id 0): draw again.\n",
"            continue\n",
"        decoder_tokens.append(word)\n",
"        if word == \"eos\":\n",
"            break\n",
"\n",
"    # Drop the markers token-wise so words that merely contain \"eos\" stay intact.\n",
"    sentence = \" \".join(token for token in decoder_tokens if token not in (\"gos\", \"eos\"))\n",
"    if sentence:\n",
"        res.add(sentence)\n",
"\n",
"# sorted() makes the printed order reproducible; set iteration order is not.\n",
"for sentence in sorted(res):\n",
"    if \"broke\" in sentence:\n",
"        print(sentence)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment