Created
November 19, 2015 14:23
-
-
Save xccds/8f0e5b0fe4eb6193261d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"collapsed": true | |
}, | |
"source": [ | |
"### 使用深度学习库keras做文本分类\n", | |
"- 数据是sogou的[语料库](http://www.sogou.com/labs/dl/c.html)\n", | |
"- 方法是卷积神经网络,可以参考kim的那篇文献\n", | |
"- 工具是keras库,它是基于theano构建的深度学习框架\n", | |
"- 问题是对sogou的新闻进行自动分类" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from os import path\n", | |
"import os\n", | |
"import re\n", | |
"import codecs\n", | |
"import pandas as pd\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['SogouC.reduced/Reduced/C000022',\n", | |
" 'SogouC.reduced/Reduced/C000023',\n", | |
" 'SogouC.reduced/Reduced/C000016',\n", | |
" 'SogouC.reduced/Reduced/C000008',\n", | |
" 'SogouC.reduced/Reduced/C000024',\n", | |
" 'SogouC.reduced/Reduced/C000010',\n", | |
" 'SogouC.reduced/Reduced/C000013',\n", | |
" 'SogouC.reduced/Reduced/C000020',\n", | |
" 'SogouC.reduced/Reduced/C000014']" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"rootdir = 'SogouC.reduced/Reduced'\n", | |
"dirs = os.listdir(rootdir)\n", | |
"dirs = [path.join(rootdir,f) for f in dirs if f.startswith('C')]\n", | |
"dirs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def load_txt(x):\n", | |
" with open(x) as f:\n", | |
" res = [t.decode('gbk','ignore') for t in f]\n", | |
" return ''.join(res)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"俄制“Ansat-LL”轻型试验用直升机\r\n", | |
" 俄罗斯lenta网站2006年5月2日报道 \r\n", | |
"梁赞直升机厂向俄罗斯海军交付了轻型试验用直升机“Ansat-LL”。海军将使用这种直升机进行各种武器装备的试验。俄海军总司令玛索林称,直升机将用于发展海军的武器装备。直升机对于继续发展俄海军武器装备具有十分重要的意义。轻型多用途直升机“Ansat”有几种型号,分为进攻型,运输型、客机型、医用型和训练型。于1994年开始设计。 \r\n", | |
"由梁赞直升机厂和“雷达”科研生产联合体共同研制。直升机最大飞行重量3.3吨,可在520千米和距离上运载1.3吨的有效负载,乘员为9人。\n" | |
] | |
} | |
], | |
"source": [ | |
"print load_txt('SogouC.reduced/Reduced/C000024/30.txt')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"text_t = {}\n", | |
"for i, d in enumerate(dirs):\n", | |
" files = os.listdir(d)\n", | |
" files = [path.join(d, x) for x in files if x.endswith('txt') and not x.startswith('.')]\n", | |
" text_t[i] = [load_txt(f) for f in files]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"flen = [len(t) for t in text_t.values()]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"labels = np.repeat(text_t.keys(),flen)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# flatter nested list\n", | |
"import itertools\n", | |
"merged = list(itertools.chain.from_iterable(text_t.values()))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>label</th>\n", | |
" <th>txt</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>记者: 刚刚结束的“office lady榜样”评选中,你被《瑞丽》评为“office ...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0</td>\n", | |
" <td>本报讯(记者 王佳琳 通讯员 唐松寒) 从昨天开始,北京市4689家非连续生产型工业企业...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0</td>\n", | |
" <td>第1页:如果你是透明人你会想做些什么事第2页:A你的野心很大第3页:B你自觉能力不错第4页:...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0</td>\n", | |
" <td>面对应聘者迫切的求职心理和对高薪的渴望,一些企业打出了过激的招聘启事。\\r\\n 专家指...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0</td>\n", | |
" <td>第1页:顶着压力办网站第2页:网上收废是发展方向\\r\\n 废品网站为居民解难\\r\\n 沸...</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" label txt\n", | |
"0 0 记者: 刚刚结束的“office lady榜样”评选中,你被《瑞丽》评为“office ...\n", | |
"1 0 本报讯(记者 王佳琳 通讯员 唐松寒) 从昨天开始,北京市4689家非连续生产型工业企业...\n", | |
"2 0 第1页:如果你是透明人你会想做些什么事第2页:A你的野心很大第3页:B你自觉能力不错第4页:...\n", | |
"3 0 面对应聘者迫切的求职心理和对高薪的渴望,一些企业打出了过激的招聘启事。\\r\\n 专家指...\n", | |
"4 0 第1页:顶着压力办网站第2页:网上收废是发展方向\\r\\n 废品网站为居民解难\\r\\n 沸..." | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df = pd.DataFrame({'label': labels, 'txt': merged})\n", | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"df['ready_seg'] =df['txt'].str.replace(ur'\\W+', ' ',flags=re.U) # 非正常字符转空格\n", | |
"df['ready_seg'] =df['ready_seg'].str.replace(r'[A-Za-z]+', ' ENG ') # 英文转ENG\n", | |
"df['ready_seg'] =df['ready_seg'].str.replace(r'\\d+', ' NUM ') # 数字转NUM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# cut word\n", | |
"import jieba\n", | |
"def cutword_1(x):\n", | |
" words = jieba.cut(x)\n", | |
" return ' '.join(words)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 169, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"df['seg_word'] = df.ready_seg.map(cutword_1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 172, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>label</th>\n", | |
" <th>txt</th>\n", | |
" <th>ready_seg</th>\n", | |
" <th>seg_word</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>记者: 刚刚结束的“office lady榜样”评选中,你被《瑞丽》评为“office ...</td>\n", | |
" <td>记者 刚刚结束的 ENG ENG 榜样 评选中 你被 瑞丽 评为 ENG EN...</td>\n", | |
" <td>记者 刚刚 结束 的 ENG ENG 榜样 评选 中 ...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0</td>\n", | |
" <td>本报讯(记者 王佳琳 通讯员 唐松寒) 从昨天开始,北京市4689家非连续生产型工业企业...</td>\n", | |
" <td>本报讯 记者 王佳琳 通讯员 唐松寒 从昨天开始 北京市 NUM 家非连续生产型工业企业 ...</td>\n", | |
" <td>本报讯 记者 王佳琳 通讯员 唐松寒 从 昨天 开始 北京市 ...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0</td>\n", | |
" <td>第1页:如果你是透明人你会想做些什么事第2页:A你的野心很大第3页:B你自觉能力不错第4页:...</td>\n", | |
" <td>第 NUM 页 如果你是透明人你会想做些什么事第 NUM 页 ENG 你的野心很大第 NU...</td>\n", | |
" <td>第 NUM 页 如果 你 是 透明人 你 会 想 做些 什么 事 第 NUM...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0</td>\n", | |
" <td>面对应聘者迫切的求职心理和对高薪的渴望,一些企业打出了过激的招聘启事。\\r\\n 专家指...</td>\n", | |
" <td>面对应聘者迫切的求职心理和对高薪的渴望 一些企业打出了过激的招聘启事 专家指出 民间统计 ...</td>\n", | |
" <td>面对 应聘者 迫切 的 求职 心理 和 对 高薪 的 渴望 一些 企业 打出 了 过...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0</td>\n", | |
" <td>第1页:顶着压力办网站第2页:网上收废是发展方向\\r\\n 废品网站为居民解难\\r\\n 沸...</td>\n", | |
" <td>第 NUM 页 顶着压力办网站第 NUM 页 网上收废是发展方向 废品网站为居民解难 沸沸洋...</td>\n", | |
" <td>第 NUM 页 顶 着 压力 办 网站 第 NUM 页 网上 收废 ...</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" label txt \\\n", | |
"0 0 记者: 刚刚结束的“office lady榜样”评选中,你被《瑞丽》评为“office ... \n", | |
"1 0 本报讯(记者 王佳琳 通讯员 唐松寒) 从昨天开始,北京市4689家非连续生产型工业企业... \n", | |
"2 0 第1页:如果你是透明人你会想做些什么事第2页:A你的野心很大第3页:B你自觉能力不错第4页:... \n", | |
"3 0 面对应聘者迫切的求职心理和对高薪的渴望,一些企业打出了过激的招聘启事。\\r\\n 专家指... \n", | |
"4 0 第1页:顶着压力办网站第2页:网上收废是发展方向\\r\\n 废品网站为居民解难\\r\\n 沸... \n", | |
"\n", | |
" ready_seg \\\n", | |
"0 记者 刚刚结束的 ENG ENG 榜样 评选中 你被 瑞丽 评为 ENG EN... \n", | |
"1 本报讯 记者 王佳琳 通讯员 唐松寒 从昨天开始 北京市 NUM 家非连续生产型工业企业 ... \n", | |
"2 第 NUM 页 如果你是透明人你会想做些什么事第 NUM 页 ENG 你的野心很大第 NU... \n", | |
"3 面对应聘者迫切的求职心理和对高薪的渴望 一些企业打出了过激的招聘启事 专家指出 民间统计 ... \n", | |
"4 第 NUM 页 顶着压力办网站第 NUM 页 网上收废是发展方向 废品网站为居民解难 沸沸洋... \n", | |
"\n", | |
" seg_word \n", | |
"0 记者 刚刚 结束 的 ENG ENG 榜样 评选 中 ... \n", | |
"1 本报讯 记者 王佳琳 通讯员 唐松寒 从 昨天 开始 北京市 ... \n", | |
"2 第 NUM 页 如果 你 是 透明人 你 会 想 做些 什么 事 第 NUM... \n", | |
"3 面对 应聘者 迫切 的 求职 心理 和 对 高薪 的 渴望 一些 企业 打出 了 过... \n", | |
"4 第 NUM 页 顶 着 压力 办 网站 第 NUM 页 网上 收废 ... " | |
] | |
}, | |
"execution_count": 172, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 文本整理完毕,后面建模需要将词汇转成数字编号,可以人工转,也可以让keras转" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 173, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"textraw = df.seg_word.values.tolist()\n", | |
"textraw = [line.encode('utf-8') for line in textraw] # 需要存为str才能被keras使用" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 273, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# keras处理token\n", | |
"maxfeatures = 50000 # 只选择最重要的词\n", | |
"from keras.preprocessing.text import Tokenizer\n", | |
"token = Tokenizer(nb_words=maxfeatures)\n", | |
"token.fit_on_texts(textraw) #如果文本较大可以使用文本流\n", | |
"text_seq = token.texts_to_sequences(textraw)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 175, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"#maxfeatures = len(token.word_counts)\n", | |
"#print maxfeatures # 语料库的词汇个数" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 264, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"400.0" | |
] | |
}, | |
"execution_count": 264, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.median([len(x) for x in text_seq]) # 每条新闻平均400个词汇" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 177, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"9\n" | |
] | |
} | |
], | |
"source": [ | |
"y = df.label.values # 定义好标签\n", | |
"nb_classes = len(np.unique(y))\n", | |
"print(nb_classes)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 316, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from __future__ import absolute_import\n", | |
"from keras.optimizers import RMSprop\n", | |
"from keras.preprocessing import sequence\n", | |
"from keras.models import Sequential\n", | |
"from keras.layers.core import Dense, Dropout, Activation, Flatten\n", | |
"from keras.layers.embeddings import Embedding\n", | |
"from keras.layers.convolutional import Convolution1D, MaxPooling1D\n", | |
"from keras.layers.recurrent import SimpleRNN, GRU, LSTM\n", | |
"from keras.callbacks import EarlyStopping" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 374, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"maxlen = 600 # 定义文本最大长度\n", | |
"batch_size = 32 # 批次\n", | |
"word_dim = 100 # 词向量维度\n", | |
"nb_filter = 200 # 卷积核个数\n", | |
"filter_length = 10 # 卷积窗口大小\n", | |
"hidden_dims = 50 # 隐藏层神经元个数\n", | |
"nb_epoch = 10 # 训练迭代次数\n", | |
"pool_length = 50 # 池化窗口大小" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 275, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.cross_validation import train_test_split\n", | |
"train_X, test_X, train_y, test_y = train_test_split(text_seq, y , train_size=0.8, random_state=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 276, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Pad sequences (samples x time)\n", | |
"('X_train shape:', (14328, 600))\n", | |
"('X_test shape:', (3582, 600))\n" | |
] | |
} | |
], | |
"source": [ | |
"# 转为等长矩阵,长度为maxlen\n", | |
"print(\"Pad sequences (samples x time)\")\n", | |
"X_train = sequence.pad_sequences(train_X, maxlen=maxlen,padding='post', truncating='post')\n", | |
"X_test = sequence.pad_sequences(test_X, maxlen=maxlen,padding='post', truncating='post')\n", | |
"print('X_train shape:', X_train.shape)\n", | |
"print('X_test shape:', X_test.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 277, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 将y的格式展开成one-hot\n", | |
"Y_train = np_utils.to_categorical(train_y, nb_classes)\n", | |
"Y_test = np_utils.to_categorical(test_y, nb_classes)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 375, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Build model...\n" | |
] | |
} | |
], | |
"source": [ | |
"# CNN 模型\n", | |
"print('Build model...')\n", | |
"model = Sequential()\n", | |
"\n", | |
"# 词向量嵌入层,输入:词典大小,词向量大小,文本长度\n", | |
"model.add(Embedding(maxfeatures, word_dim,input_length=maxlen)) \n", | |
"model.add(Dropout(0.25))\n", | |
"model.add(Convolution1D(nb_filter=nb_filter,\n", | |
" filter_length=filter_length,\n", | |
" border_mode=\"valid\",\n", | |
" activation=\"relu\"))\n", | |
"# 池化层\n", | |
"model.add(MaxPooling1D(pool_length=pool_length))\n", | |
"model.add(Flatten())\n", | |
"# 全连接层\n", | |
"model.add(Dense(hidden_dims))\n", | |
"model.add(Dropout(0.25))\n", | |
"model.add(Activation('relu'))\n", | |
"model.add(Dense(nb_classes))\n", | |
"model.add(Activation('softmax'))\n", | |
"model.compile(loss='categorical_crossentropy', optimizer='rmsprop')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 376, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train on 12895 samples, validate on 1433 samples\n", | |
"Epoch 1/10\n", | |
"12895/12895 [==============================] - 17s - loss: 2.0669 - acc: 0.1978 - val_loss: 1.7813 - val_acc: 0.3433\n", | |
"Epoch 2/10\n", | |
"12895/12895 [==============================] - 17s - loss: 1.3790 - acc: 0.4902 - val_loss: 0.8512 - val_acc: 0.7551\n", | |
"Epoch 3/10\n", | |
"12895/12895 [==============================] - 17s - loss: 0.6186 - acc: 0.8237 - val_loss: 0.5232 - val_acc: 0.8486\n", | |
"Epoch 4/10\n", | |
"12895/12895 [==============================] - 17s - loss: 0.3897 - acc: 0.8955 - val_loss: 0.4628 - val_acc: 0.8758\n", | |
"Epoch 5/10\n", | |
"12895/12895 [==============================] - 17s - loss: 0.2638 - acc: 0.9288 - val_loss: 0.4571 - val_acc: 0.8793\n", | |
"Epoch 6/10\n", | |
"12895/12895 [==============================] - 17s - loss: 0.1820 - acc: 0.9530 - val_loss: 0.5091 - val_acc: 0.8779\n", | |
"Epoch 7/10\n", | |
"12895/12895 [==============================] - 17s - loss: 0.1192 - acc: 0.9684 - val_loss: 0.5027 - val_acc: 0.8786\n", | |
"Epoch 00006: early stopping\n" | |
] | |
} | |
], | |
"source": [ | |
"earlystop = EarlyStopping(monitor='val_loss', patience=1, verbose=1)\n", | |
"result = model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, \n", | |
" validation_split=0.1, show_accuracy=True,callbacks=[earlystop])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 377, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"3582/3582 [==============================] - 1s \n", | |
"('Test score:', 0.44877443717618948)\n", | |
"3582/3582 [==============================] - 1s \n", | |
"('Test accuracy:', 0.88972640982691231)\n" | |
] | |
} | |
], | |
"source": [ | |
"score = earlystop.model.evaluate(X_test, Y_test, batch_size=batch_size)\n", | |
"print('Test score:', score)\n", | |
"classes = earlystop.model.predict_classes(X_test, batch_size=batch_size)\n", | |
"acc = np_utils.accuracy(classes, test_y) # 要用没有转换前的y\n", | |
"print('Test accuracy:', acc)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 378, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO (theano.gof.compilelock): Refreshing lock /home/openmind/.theano/compiledir_Linux-3.19--generic-x86_64-with-debian-jessie-sid-x86_64-2.7.10-64/lock_dir/lock\n", | |
"INFO:theano.gof.compilelock:Refreshing lock /home/openmind/.theano/compiledir_Linux-3.19--generic-x86_64-with-debian-jessie-sid-x86_64-2.7.10-64/lock_dir/lock\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Build model...\n" | |
] | |
} | |
], | |
"source": [ | |
"# LSTM\n", | |
"print('Build model...')\n", | |
"model = Sequential()\n", | |
"\n", | |
"# 词向量嵌入层,输入:词典大小,词向量大小,文本长度\n", | |
"model.add(Embedding(maxfeatures, word_dim,input_length=maxlen)) \n", | |
"#model.add(Dropout(0.25))\n", | |
"model.add(LSTM(100)) \n", | |
"model.add(Flatten())\n", | |
"# 全连接层\n", | |
"model.add(Dense(hidden_dims))\n", | |
"model.add(Dropout(0.25))\n", | |
"model.add(Activation('relu'))\n", | |
"model.add(Dense(nb_classes))\n", | |
"model.add(Activation('softmax'))\n", | |
"model.compile(loss='categorical_crossentropy', optimizer='rmsprop')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"result = model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=1, \n", | |
" validation_split=0.1, show_accuracy=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 338, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO (theano.gof.compilelock): Refreshing lock /home/openmind/.theano/compiledir_Linux-3.19--generic-x86_64-with-debian-jessie-sid-x86_64-2.7.10-64/lock_dir/lock\n", | |
"INFO:theano.gof.compilelock:Refreshing lock /home/openmind/.theano/compiledir_Linux-3.19--generic-x86_64-with-debian-jessie-sid-x86_64-2.7.10-64/lock_dir/lock\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Build model...\n" | |
] | |
} | |
], | |
"source": [ | |
"# CNN + LSTM\n", | |
"print('Build model...')\n", | |
"model = Sequential()\n", | |
"\n", | |
"# 词向量嵌入层,输入:词典大小,词向量大小,文本长度\n", | |
"model.add(Embedding(maxfeatures, word_dim,input_length=maxlen)) \n", | |
"model.add(Dropout(0.25))\n", | |
"model.add(Convolution1D(nb_filter=nb_filter,\n", | |
" filter_length=filter_length,\n", | |
" border_mode=\"valid\",\n", | |
" activation=\"relu\"))\n", | |
"# 池化层\n", | |
"model.add(MaxPooling1D(pool_length=pool_length))\n", | |
"# lstm\n", | |
"model.add(LSTM(100))\n", | |
"# 全连接层\n", | |
"#model.add(Flatten())\n", | |
"model.add(Dropout(0.25))\n", | |
"model.add(Activation('relu'))\n", | |
"model.add(Dense(nb_classes))\n", | |
"model.add(Activation('softmax'))\n", | |
"model.compile(loss='categorical_crossentropy', optimizer='rmsprop')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 371, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Build model...\n" | |
] | |
} | |
], | |
"source": [ | |
"# 整合3个cnn\n", | |
"from keras.models import Graph\n", | |
"fw = [2,10, 5]\n", | |
"pool_length = [2,50, 10]\n", | |
"print('Build model...')\n", | |
"graph = Graph()\n", | |
"graph.add_input(name='input', input_shape=(maxlen,), dtype=int)\n", | |
"graph.add_node(Embedding(maxfeatures, word_dim, input_length=maxlen),\n", | |
" name='embedding', input='input')\n", | |
"\n", | |
"# 卷积2个字\n", | |
"graph.add_node(Convolution1D(nb_filter=nb_filter,filter_length=fw[0],\n", | |
" activation=\"relu\"),\n", | |
" name='conv1', input='embedding') \n", | |
"graph.add_node(MaxPooling1D(pool_length =pool_length[0], ignore_border = False), name='pool1', input = 'conv1')\n", | |
"graph.add_node(Flatten(), name='flat1', input='conv1')\n", | |
"\n", | |
"\n", | |
"# 卷积10个字\n", | |
"graph.add_node(Convolution1D(nb_filter=nb_filter,filter_length=fw[1],\n", | |
" activation=\"relu\"),\n", | |
" name='conv2', input='embedding') \n", | |
"graph.add_node(MaxPooling1D(pool_length =pool_length[1], ignore_border = False), name='pool2', input = 'conv2')\n", | |
"graph.add_node(Flatten(), name='flat2', input='conv2')\n", | |
"\n", | |
"#卷积5个字\n", | |
"graph.add_node(Convolution1D(nb_filter=nb_filter,filter_length=fw[2],\n", | |
" activation=\"relu\"),\n", | |
" name='conv3', input='embedding') \n", | |
"graph.add_node(MaxPooling1D(pool_length =pool_length[2], ignore_border = False), name='pool3', input = 'conv3')\n", | |
"graph.add_node(Flatten(), name='flat3', input='conv3')\n", | |
"\n", | |
"\n", | |
"# 整合\n", | |
"graph.add_node(Dense(hidden_dims,activation='relu'), name='dense1', \n", | |
" inputs=['flat1', 'flat2', 'flat3'], merge_mode='concat')\n", | |
"graph.add_node(Dropout(0.5), name='drop1', input='dense1')\n", | |
"graph.add_node(Dense(nb_classes, activation='softmax'), name='softmax', input='drop1')\n", | |
"graph.add_output(name='output', input='softmax')\n", | |
"graph.compile('Adam', loss = {'output': 'categorical_crossentropy'})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 372, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train on 12895 samples, validate on 1433 samples\n", | |
"Epoch 1/3\n", | |
"12895/12895 [==============================] - 37s - loss: 1.4137 - val_loss: 0.5643\n", | |
"Epoch 2/3\n", | |
"12895/12895 [==============================] - 37s - loss: 0.6103 - val_loss: 0.4328\n", | |
"Epoch 3/3\n", | |
"12895/12895 [==============================] - 37s - loss: 0.3905 - val_loss: 0.4688\n" | |
] | |
} | |
], | |
"source": [ | |
"result = graph.fit({'input':X_train, 'output':Y_train}, \n", | |
" nb_epoch=3,batch_size=batch_size,\n", | |
" validation_split=0.1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 370, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"('Test accuracy:', 0.89335566722501392)\n" | |
] | |
} | |
], | |
"source": [ | |
"predict = graph.predict({'input':X_test}, batch_size=batch_size)\n", | |
"predict = predict['output']\n", | |
"classes = predict.argmax(axis=1)\n", | |
"acc = np_utils.accuracy(classes, test_y) # 要用没有转换前的y\n", | |
"print('Test accuracy:', acc)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment