Keras usage example, simple text classification
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This notebook is based on [this example from Francois Chollet](https://github.com/fchollet/keras/blob/master/examples/reuters_mlp.py)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Train and evaluate a simple MLP on the 20 newsgroups topic classification task." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"import keras\n", | |
"from keras.models import Sequential, Model\n", | |
"from keras.layers import Dense, Dropout, Activation, Input\n", | |
"from keras.preprocessing.text import Tokenizer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"max_words = 1000\n", | |
"batch_size = 32\n", | |
"epochs = 5" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.datasets import fetch_20newsgroups\n", | |
"newsgroups_train = fetch_20newsgroups(subset='train')\n", | |
"newsgroups_test = fetch_20newsgroups(subset='test')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"From: [email protected] (where's my thing)\n", | |
"Subject: WHAT car is this!?\n", | |
"Nntp-Posting-Host: rac3.wam.umd.edu\n", | |
"Organization: University of Maryland, College Park\n", | |
"Lines: 15\n", | |
"\n", | |
" I was wondering if anyone out there could enlighten me on this car I saw\n", | |
"the other day. It was a 2-door sports car, looked to be from the late 60s/\n", | |
"early 70s. It was called a Bricklin. The doors were really small. In addition,\n", | |
"the front bumper was separate from the rest of the body. This is \n", | |
"all I know. If anyone can tellme a model name, engine specs, years\n", | |
"of production, where this car is made, history, or whatever info you\n", | |
"have on this funky looking car, please e-mail.\n", | |
"\n", | |
"Thanks,\n", | |
"- IL\n", | |
" ---- brought to you by your neighborhood Lerxst ----\n", | |
"\n", | |
"\n", | |
"\n", | |
"\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"print(newsgroups_train[\"data\"][0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Preparing the Tokenizer...\n" | |
] | |
} | |
], | |
"source": [ | |
"print(\"Preparing the Tokenizer...\")\n", | |
"tokenizer = Tokenizer(num_words=max_words)\n", | |
"tokenizer.fit_on_texts(newsgroups_train[\"data\"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Vectorizing sequence data...\n", | |
"x_train shape: (11314, 1000)\n", | |
"x_test shape: (7532, 1000)\n" | |
] | |
} | |
], | |
"source": [ | |
"print('Vectorizing sequence data...')\n", | |
"x_train = tokenizer.texts_to_matrix(newsgroups_train[\"data\"], mode='binary')\n", | |
"x_test = tokenizer.texts_to_matrix(newsgroups_test[\"data\"], mode='binary')\n", | |
"print('x_train shape:', x_train.shape)\n", | |
"print('x_test shape:', x_test.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0., 1., 1., ..., 0., 0., 0.],\n", | |
" [ 0., 1., 1., ..., 0., 0., 0.],\n", | |
" [ 0., 1., 1., ..., 1., 0., 0.],\n", | |
" ..., \n", | |
" [ 0., 1., 1., ..., 0., 0., 0.],\n", | |
" [ 0., 1., 1., ..., 0., 0., 0.],\n", | |
" [ 0., 0., 0., ..., 0., 0., 0.]])" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x_train" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"20 classes\n" | |
] | |
} | |
], | |
"source": [ | |
"num_classes = np.max(newsgroups_train[\"target\"]) + 1\n", | |
"print(num_classes, 'classes')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Convert class vector to binary class matrix (for use with categorical_crossentropy)\n", | |
"y_train shape: (11314, 20)\n", | |
"y_test shape: (7532, 20)\n" | |
] | |
} | |
], | |
"source": [ | |
"print('Convert class vector to binary class matrix '\n", | |
" '(for use with categorical_crossentropy)')\n", | |
"y_train = keras.utils.to_categorical(newsgroups_train[\"target\"], num_classes)\n", | |
"y_test = keras.utils.to_categorical(newsgroups_test[\"target\"], num_classes)\n", | |
"print('y_train shape:', y_train.shape)\n", | |
"print('y_test shape:', y_test.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0., 0., 0., ..., 0., 0., 0.],\n", | |
" [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
" [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
" ..., \n", | |
" [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
" [ 0., 1., 0., ..., 0., 0., 0.],\n", | |
" [ 0., 0., 0., ..., 0., 0., 0.]])" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_train" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Building model sequentially 1...\n" | |
] | |
} | |
], | |
"source": [ | |
"print('Building model sequentially 1...')\n", | |
"model = Sequential()\n", | |
"model.add(Dense(512, input_shape=(max_words,)))\n", | |
"model.add(Activation('relu'))\n", | |
"model.add(Dropout(0.5))\n", | |
"model.add(Dense(num_classes))\n", | |
"model.add(Activation('softmax'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Building model sequentially 2...\n" | |
] | |
} | |
], | |
"source": [ | |
"print('Building model sequentially 2...')\n", | |
"model = Sequential([\n", | |
" Dense(512, input_shape=(max_words,)),\n", | |
" Activation('relu'),\n", | |
" Dropout(0.5),\n", | |
" Dense(num_classes),\n", | |
" Activation('softmax')\n", | |
" ])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<keras.layers.core.Dense at 0x1123c4b00>,\n", | |
" <keras.layers.core.Activation at 0x1122db780>,\n", | |
" <keras.layers.core.Dropout at 0x1122db940>,\n", | |
" <keras.layers.core.Dense at 0x1122dbe10>,\n", | |
" <keras.layers.core.Activation at 0x112325390>]" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.layers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"backend: tensorflow\n", | |
"class_name: Sequential\n", | |
"config:\n", | |
"- class_name: Dense\n", | |
" config:\n", | |
" activation: linear\n", | |
" activity_regularizer: null\n", | |
" batch_input_shape: !!python/tuple [null, 1000]\n", | |
" bias_constraint: null\n", | |
" bias_initializer:\n", | |
" class_name: Zeros\n", | |
" config: {}\n", | |
" bias_regularizer: null\n", | |
" dtype: float32\n", | |
" kernel_constraint: null\n", | |
" kernel_initializer:\n", | |
" class_name: VarianceScaling\n", | |
" config: {distribution: uniform, mode: fan_avg, scale: 1.0, seed: null}\n", | |
" kernel_regularizer: null\n", | |
" name: dense_3\n", | |
" trainable: true\n", | |
" units: 512\n", | |
" use_bias: true\n", | |
"- class_name: Activation\n", | |
" config: {activation: relu, name: activation_3, trainable: true}\n", | |
"- class_name: Dropout\n", | |
" config: {name: dropout_2, rate: 0.5, trainable: true}\n", | |
"- class_name: Dense\n", | |
" config:\n", | |
" activation: linear\n", | |
" activity_regularizer: null\n", | |
" bias_constraint: null\n", | |
" bias_initializer:\n", | |
" class_name: Zeros\n", | |
" config: {}\n", | |
" bias_regularizer: null\n", | |
" kernel_constraint: null\n", | |
" kernel_initializer:\n", | |
" class_name: VarianceScaling\n", | |
" config: {distribution: uniform, mode: fan_avg, scale: 1.0, seed: null}\n", | |
" kernel_regularizer: null\n", | |
" name: dense_4\n", | |
" trainable: true\n", | |
" units: !!python/object/apply:numpy.core.multiarray.scalar\n", | |
" - !!python/object/apply:numpy.dtype\n", | |
" args: [i8, 0, 1]\n", | |
" state: !!python/tuple [3, <, null, null, null, -1, -1, 0]\n", | |
" - !!binary |\n", | |
" FAAAAAAAAAA=\n", | |
" use_bias: true\n", | |
"- class_name: Activation\n", | |
" config: {activation: softmax, name: activation_4, trainable: true}\n", | |
"keras_version: 2.0.2\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"print(model.to_yaml())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Building model functionally...\n" | |
] | |
} | |
], | |
"source": [ | |
"print('Building model functionally...')\n", | |
"a = Input(shape=(max_words,))\n", | |
"b = Dense(512)(a)\n", | |
"b = Activation('relu')(b)\n", | |
"b = Dropout(0.5)(b)\n", | |
"b = Dense(num_classes)(b)\n", | |
"b = Activation('softmax')(b)\n", | |
"model = Model(inputs=a, outputs=b)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from keras.models import model_from_yaml\n", | |
"\n", | |
"yaml_string = model.to_yaml()\n", | |
"model = model_from_yaml(yaml_string)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied (use --upgrade to upgrade): pydot-ng in /Users/madrugado/anaconda3/lib/python3.5/site-packages\n", | |
"Requirement already satisfied (use --upgrade to upgrade): pyparsing>=2.0.1 in /Users/madrugado/anaconda3/lib/python3.5/site-packages (from pydot-ng)\n", | |
"\u001b[33mYou are using pip version 8.1.2, however version 9.0.1 is available.\n", | |
"You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"! pip install pydot-ng" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from keras.utils import plot_model\n", | |
"plot_model(model, to_file='model.png', show_shapes=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<svg height=\"458pt\" viewBox=\"0.00 0.00 298.24 458.00\" width=\"298pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 454)\">\n", | |
"<title>G</title>\n", | |
"<polygon fill=\"white\" points=\"-4,4 -4,-454 294.238,-454 294.238,4 -4,4\" stroke=\"none\"/>\n", | |
"<!-- 4600907424 -->\n", | |
"<g class=\"node\" id=\"node1\"><title>4600907424</title>\n", | |
"<polygon fill=\"none\" points=\"7.7793,-405.5 7.7793,-449.5 282.459,-449.5 282.459,-405.5 7.7793,-405.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"71.9604\" y=\"-423.3\">input_1: InputLayer</text>\n", | |
"<polyline fill=\"none\" points=\"136.142,-405.5 136.142,-449.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"163.976\" y=\"-434.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"136.142,-427.5 191.811,-427.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"163.976\" y=\"-412.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"191.811,-405.5 191.811,-449.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"237.135\" y=\"-434.3\">(None, 1000)</text>\n", | |
"<polyline fill=\"none\" points=\"191.811,-427.5 282.459,-427.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"237.135\" y=\"-412.3\">(None, 1000)</text>\n", | |
"</g>\n", | |
"<!-- 4600907312 -->\n", | |
"<g class=\"node\" id=\"node2\"><title>4600907312</title>\n", | |
"<polygon fill=\"none\" points=\"19.8345,-324.5 19.8345,-368.5 270.404,-368.5 270.404,-324.5 19.8345,-324.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"71.9604\" y=\"-342.3\">dense_5: Dense</text>\n", | |
"<polyline fill=\"none\" points=\"124.086,-324.5 124.086,-368.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"151.921\" y=\"-353.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"124.086,-346.5 179.755,-346.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"151.921\" y=\"-331.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"179.755,-324.5 179.755,-368.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"225.08\" y=\"-353.3\">(None, 1000)</text>\n", | |
"<polyline fill=\"none\" points=\"179.755,-346.5 270.404,-346.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"225.08\" y=\"-331.3\">(None, 512)</text>\n", | |
"</g>\n", | |
"<!-- 4600907424->4600907312 -->\n", | |
"<g class=\"edge\" id=\"edge1\"><title>4600907424->4600907312</title>\n", | |
"<path d=\"M145.119,-405.329C145.119,-397.183 145.119,-387.699 145.119,-378.797\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"148.619,-378.729 145.119,-368.729 141.619,-378.729 148.619,-378.729\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"<!-- 4601527376 -->\n", | |
"<g class=\"node\" id=\"node3\"><title>4601527376</title>\n", | |
"<polygon fill=\"none\" points=\"0,-243.5 0,-287.5 290.238,-287.5 290.238,-243.5 0,-243.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"75.4604\" y=\"-261.3\">activation_5: Activation</text>\n", | |
"<polyline fill=\"none\" points=\"150.921,-243.5 150.921,-287.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"178.755\" y=\"-272.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"150.921,-265.5 206.59,-265.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"178.755\" y=\"-250.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"206.59,-243.5 206.59,-287.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.414\" y=\"-272.3\">(None, 512)</text>\n", | |
"<polyline fill=\"none\" points=\"206.59,-265.5 290.238,-265.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.414\" y=\"-250.3\">(None, 512)</text>\n", | |
"</g>\n", | |
"<!-- 4600907312->4601527376 -->\n", | |
"<g class=\"edge\" id=\"edge2\"><title>4600907312->4601527376</title>\n", | |
"<path d=\"M145.119,-324.329C145.119,-316.183 145.119,-306.699 145.119,-297.797\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"148.619,-297.729 145.119,-287.729 141.619,-297.729 148.619,-297.729\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"<!-- 4601040736 -->\n", | |
"<g class=\"node\" id=\"node4\"><title>4601040736</title>\n", | |
"<polygon fill=\"none\" points=\"11.6587,-162.5 11.6587,-206.5 278.58,-206.5 278.58,-162.5 11.6587,-162.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"75.4604\" y=\"-180.3\">dropout_3: Dropout</text>\n", | |
"<polyline fill=\"none\" points=\"139.262,-162.5 139.262,-206.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"167.097\" y=\"-191.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"139.262,-184.5 194.931,-184.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"167.097\" y=\"-169.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"194.931,-162.5 194.931,-206.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"236.755\" y=\"-191.3\">(None, 512)</text>\n", | |
"<polyline fill=\"none\" points=\"194.931,-184.5 278.58,-184.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"236.755\" y=\"-169.3\">(None, 512)</text>\n", | |
"</g>\n", | |
"<!-- 4601527376->4601040736 -->\n", | |
"<g class=\"edge\" id=\"edge3\"><title>4601527376->4601040736</title>\n", | |
"<path d=\"M145.119,-243.329C145.119,-235.183 145.119,-225.699 145.119,-216.797\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"148.619,-216.729 145.119,-206.729 141.619,-216.729 148.619,-216.729\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"<!-- 4600579912 -->\n", | |
"<g class=\"node\" id=\"node5\"><title>4600579912</title>\n", | |
"<polygon fill=\"none\" points=\"23.3345,-81.5 23.3345,-125.5 266.904,-125.5 266.904,-81.5 23.3345,-81.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"75.4604\" y=\"-99.3\">dense_6: Dense</text>\n", | |
"<polyline fill=\"none\" points=\"127.586,-81.5 127.586,-125.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"155.421\" y=\"-110.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"127.586,-103.5 183.255,-103.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"155.421\" y=\"-88.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"183.255,-81.5 183.255,-125.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"225.08\" y=\"-110.3\">(None, 512)</text>\n", | |
"<polyline fill=\"none\" points=\"183.255,-103.5 266.904,-103.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"225.08\" y=\"-88.3\">(None, 20)</text>\n", | |
"</g>\n", | |
"<!-- 4601040736->4600579912 -->\n", | |
"<g class=\"edge\" id=\"edge4\"><title>4601040736->4600579912</title>\n", | |
"<path d=\"M145.119,-162.329C145.119,-154.183 145.119,-144.699 145.119,-135.797\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"148.619,-135.729 145.119,-125.729 141.619,-135.729 148.619,-135.729\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"<!-- 4601423128 -->\n", | |
"<g class=\"node\" id=\"node6\"><title>4601423128</title>\n", | |
"<polygon fill=\"none\" points=\"3.5,-0.5 3.5,-44.5 286.738,-44.5 286.738,-0.5 3.5,-0.5\" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"78.9604\" y=\"-18.3\">activation_6: Activation</text>\n", | |
"<polyline fill=\"none\" points=\"154.421,-0.5 154.421,-44.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"182.255\" y=\"-29.3\">input:</text>\n", | |
"<polyline fill=\"none\" points=\"154.421,-22.5 210.09,-22.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"182.255\" y=\"-7.3\">output:</text>\n", | |
"<polyline fill=\"none\" points=\"210.09,-0.5 210.09,-44.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.414\" y=\"-29.3\">(None, 20)</text>\n", | |
"<polyline fill=\"none\" points=\"210.09,-22.5 286.738,-22.5 \" stroke=\"black\"/>\n", | |
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.414\" y=\"-7.3\">(None, 20)</text>\n", | |
"</g>\n", | |
"<!-- 4600579912->4601423128 -->\n", | |
"<g class=\"edge\" id=\"edge5\"><title>4600579912->4601423128</title>\n", | |
"<path d=\"M145.119,-81.3294C145.119,-73.1826 145.119,-63.6991 145.119,-54.7971\" fill=\"none\" stroke=\"black\"/>\n", | |
"<polygon fill=\"black\" points=\"148.619,-54.729 145.119,-44.729 141.619,-54.729 148.619,-54.729\" stroke=\"black\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.SVG object>" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from IPython.display import SVG\n", | |
"from keras.utils.vis_utils import model_to_dot\n", | |
"\n", | |
"SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from keras.objectives import categorical_crossentropy\n", | |
"from keras import backend as K\n", | |
"\n", | |
"epsilon = 1.0e-9\n", | |
"def custom_objective(y_true, y_pred):\n", | |
" '''Yet another crossentropy'''\n", | |
" y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)\n", | |
" y_pred /= K.sum(y_pred, axis=-1, keepdims=True)\n", | |
" cce = categorical_crossentropy(y_pred, y_true)\n", | |
" return cce" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"model.compile(loss='categorical_crossentropy',\n", | |
" optimizer='adam',\n", | |
" metrics=['accuracy'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"model.compile(loss=custom_objective,\n", | |
" optimizer='adam',\n", | |
" metrics=['accuracy'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train on 10182 samples, validate on 1132 samples\n", | |
"Epoch 1/5\n", | |
"10182/10182 [==============================] - 6s - loss: 11.2602 - acc: 0.3476 - val_loss: 8.1297 - val_acc: 0.5442\n", | |
"Epoch 2/5\n", | |
"10182/10182 [==============================] - 5s - loss: 7.1478 - acc: 0.5968 - val_loss: 6.9277 - val_acc: 0.5998\n", | |
"Epoch 3/5\n", | |
"10182/10182 [==============================] - 5s - loss: 5.6592 - acc: 0.6782 - val_loss: 5.8904 - val_acc: 0.6564\n", | |
"Epoch 4/5\n", | |
"10182/10182 [==============================] - 5s - loss: 4.8580 - acc: 0.7209 - val_loss: 5.7133 - val_acc: 0.6643\n", | |
"Epoch 5/5\n", | |
"10182/10182 [==============================] - 5s - loss: 4.5376 - acc: 0.7376 - val_loss: 5.5546 - val_acc: 0.6687\n" | |
] | |
} | |
], | |
"source": [ | |
"history = model.fit(x_train, y_train,\n", | |
" batch_size=batch_size,\n", | |
" epochs=epochs,\n", | |
" verbose=1,\n", | |
" validation_split=0.1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train on 10182 samples, validate on 1132 samples\n", | |
"Epoch 1/5\n", | |
"10182/10182 [==============================] - 6s - loss: 5.0572 - acc: 0.6996 - val_loss: 5.9746 - val_acc: 0.6396\n", | |
"Epoch 2/5\n", | |
"10182/10182 [==============================] - 6s - loss: 4.8309 - acc: 0.7135 - val_loss: 5.9775 - val_acc: 0.6396\n" | |
] | |
} | |
], | |
"source": [ | |
"from keras.callbacks import EarlyStopping \n", | |
"early_stopping=EarlyStopping(monitor='val_loss') \n", | |
"\n", | |
"history = model.fit(x_train, y_train,\n", | |
" batch_size=batch_size,\n", | |
" epochs=epochs,\n", | |
" verbose=1,\n", | |
" validation_split=0.1,\n", | |
" callbacks=[early_stopping])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train on 10182 samples, validate on 1132 samples\n", | |
"Epoch 1/5\n", | |
"10182/10182 [==============================] - 5s - loss: 4.6762 - acc: 0.7209 - val_loss: 5.8910 - val_acc: 0.6475\n", | |
"Epoch 2/5\n", | |
"10182/10182 [==============================] - 5s - loss: 4.5583 - acc: 0.7272 - val_loss: 5.8414 - val_acc: 0.6493\n", | |
"Epoch 3/5\n", | |
"10182/10182 [==============================] - 5s - loss: 4.4485 - acc: 0.7323 - val_loss: 5.9157 - val_acc: 0.6422\n", | |
"Epoch 4/5\n", | |
"10182/10182 [==============================] - 5s - loss: 4.3723 - acc: 0.7361 - val_loss: 5.9310 - val_acc: 0.6369\n", | |
"Epoch 5/5\n", | |
"10182/10182 [==============================] - 5s - loss: 4.3307 - acc: 0.7380 - val_loss: 5.7791 - val_acc: 0.6511\n" | |
] | |
} | |
], | |
"source": [ | |
"from keras.callbacks import TensorBoard \n", | |
"tensorboard=TensorBoard(log_dir='./logs', write_graph=True, write_images=True)\n", | |
"\n", | |
"history = model.fit(x_train, y_train,\n", | |
" batch_size=batch_size,\n", | |
" epochs=epochs,\n", | |
" verbose=1,\n", | |
" validation_split=0.1,\n", | |
" callbacks=[tensorboard])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"7328/7532 [============================>.] - ETA: 0s\n", | |
"\n", | |
"Test score: 7.05705657565\n", | |
"Test accuracy: 0.572623473149\n" | |
] | |
} | |
], | |
"source": [ | |
"score = model.evaluate(x_test, y_test,\n", | |
" batch_size=batch_size, verbose=1)\n", | |
"print('\\n')\n", | |
"print('Test score:', score[0])\n", | |
"print('Test accuracy:', score[1])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python [conda root]", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
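A side note on the custom_objective cell above: it is just categorical crossentropy applied to clipped and renormalized predictions. A minimal NumPy sketch of the same per-sample computation, using made-up values, may make that clearer:

import numpy as np

epsilon = 1.0e-9

def custom_objective_np(y_true, y_pred):
    """NumPy rendition of the notebook's custom_objective, for illustration only."""
    y_pred = np.clip(y_pred, epsilon, 1.0 - epsilon)      # avoid log(0)
    y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)  # renormalize to a probability distribution
    return -(y_true * np.log(y_pred)).sum(axis=-1)        # standard categorical crossentropy

y_true = np.array([[0., 1., 0.]])
y_pred = np.array([[0.2, 0.7, 0.1]])
print(custom_objective_np(y_true, y_pred))                # ~[0.357], i.e. -log(0.7)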
I'm new to machine learning. Could you please tell me how to get a prediction for a given message? I get something strange back, most likely because I'm not preprocessing the text correctly.

@bumsun, it looks like the problem is that tokenizer.texts_to_matrix expects a list of texts. If you pass it a single string, the method treats it as a list of characters and produces nonsense. The following code seems to work correctly for me:

prediction = model.predict(np.array(tokenizer.texts_to_matrix([text], mode='binary')))  # wrapped text in a list: [text]
print(prediction.shape)  # (1, 20)
print(prediction)  # an array of 20 values; the i-th element is the probability that the text belongs to the i-th category
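For a complete end-to-end example, here is a minimal sketch of a prediction helper. It assumes the tokenizer, model, and newsgroups_train objects from the notebook above are still in scope; predict_newsgroup and the sample message are made up for illustration.

import numpy as np

def predict_newsgroup(text):
    """Return the most likely newsgroup name for a single raw message string."""
    # texts_to_matrix expects a list of texts, so wrap the single string in a list;
    # the result has shape (1, max_words), matching the model's input layer.
    x = tokenizer.texts_to_matrix([text], mode='binary')
    probs = model.predict(x)[0]  # shape: (num_classes,)
    return newsgroups_train.target_names[int(np.argmax(probs))]

print(predict_newsgroup("Looking for engine specs on a late-60s two-door sports car."))
# prints something like 'rec.autos'; the actual label depends on the trained weights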