{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Chinese Word Segmentation with a Deep LSTM\n",
    " - Step 1: read the annotated training corpus and convert it into the data format keras expects.\n",
    " - Step 2: build a model on the training data using the deep LSTM approach\n",
    " - Step 3: read the unannotated test corpus and tag it with the two-layer LSTM model; adding more layers improves the result slightly\n",
    " - Step 4: check the final result: an F-score of 0.949"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Step 1: read and transform the training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import nltk\n",
    "import codecs\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from os import path\n",
    "from nltk.probability import FreqDist\n",
    "from gensim.models import word2vec\n",
    "from cPickle import load, dump"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Compute word vectors from the MSR corpus\n",
    "\n",
    "# Read a single text file\n",
    "def load_file(input_file):\n",
    "    input_data = codecs.open(input_file, 'r', 'utf-8')\n",
    "    input_text = input_data.read()\n",
    "    return input_text\n",
    "\n",
    "# Read every text file in a directory into one string\n",
    "def load_dir(input_dir):\n",
    "    files = os.listdir(input_dir)\n",
    "    seg_files_path = [path.join(input_dir, f) for f in files]\n",
    "    output = []\n",
    "    for txt in seg_files_path:\n",
    "        output.append(load_file(txt))\n",
    "    return '\\n'.join(output)\n",
    "\n",
    "# nltk: take tokenized text, return a word-frequency table\n",
    "def freq_func(input_txt):\n",
    "    corpus = nltk.Text(input_txt)\n",
    "    fdist = FreqDist(corpus)\n",
    "    w = fdist.keys()\n",
    "    v = fdist.values()\n",
    "    freqdf = pd.DataFrame({'word': w, 'freq': v})\n",
    "    freqdf.sort('freq', ascending=False, inplace=True)\n",
    "    freqdf['idx'] = np.arange(len(v))\n",
    "    return freqdf\n",
    "\n",
    "# Train the word2vec model\n",
    "def trainW2V(corpus, epochs=20, num_features=100, sg=1,\\\n",
    "             min_word_count=1, num_workers=4,\\\n",
    "             context=4, sample=1e-5, negative=5):\n",
    "    w2v = word2vec.Word2Vec(workers=num_workers,\n",
    "                            sample=sample,\n",
    "                            size=num_features,\n",
    "                            min_count=min_word_count,\n",
    "                            window=context,\n",
    "                            sg=sg,\n",
    "                            negative=negative)\n",
    "    np.random.shuffle(corpus)\n",
    "    w2v.build_vocab(corpus)\n",
    "    for epoch in range(epochs):\n",
    "        print('epoch' + str(epoch))\n",
    "        np.random.shuffle(corpus)\n",
    "        w2v.train(corpus)\n",
    "        w2v.alpha *= 0.9  # decay the learning rate each epoch\n",
    "        w2v.min_alpha = w2v.alpha\n",
    "    print(\"word2vec DONE.\")\n",
    "    return w2v\n",
    "\n",
    "# Build the word-vector lookup matrix, row i holding the vector of idx2word[i]\n",
    "def save_w2v(w2v, idx2word):\n",
    "    init_weight_wv = []\n",
    "    for i in range(len(idx2word)):\n",
    "        init_weight_wv.append(w2v[idx2word[i]])\n",
    "    return init_weight_wv\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch0\n",
      "epoch1\n",
      "epoch2\n",
      "epoch3\n",
      "epoch4\n",
      "epoch5\n",
      "epoch6\n",
      "epoch7\n",
      "epoch8\n",
      "epoch9\n",
      "epoch10\n",
      "epoch11\n",
      "epoch12\n",
      "epoch13\n",
      "epoch14\n",
      "epoch15\n",
      "epoch16\n",
      "epoch17\n",
      "epoch18\n",
      "epoch19\n",
      "word2vec DONE.\n"
     ]
    }
   ],
   "source": [
    "input_file = 'icwb2-data/training/msr_training.utf8'\n",
    "input_text = load_file(input_file)  # read the whole training text\n",
    "txtwv = [line.split() for line in input_text.split('\\n') if line != '']  # input format for word2vec\n",
    "txtnltk = [w for w in input_text.split()]  # input format for frequency counting\n",
    "freqdf = freq_func(txtnltk)  # word-frequency table\n",
    "maxfeatures = freqdf.shape[0]  # vocabulary size\n",
    "# Build the two mapping dictionaries\n",
    "word2idx = dict((c, i) for c, i in zip(freqdf.word, freqdf.idx))\n",
    "idx2word = dict((i, c) for c, i in zip(freqdf.word, freqdf.idx))\n",
    "# word2vec\n",
    "w2v = trainW2V(txtwv)\n",
    "# Store the vectors\n",
    "init_weight_wv = save_w2v(w2v, idx2word)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define 'U' for unseen (out-of-vocabulary) characters and 'P' for padding at both ends,\n",
    "# and append the two corresponding vectors\n",
    "char_num = len(init_weight_wv)\n",
    "idx2word[char_num] = u'U'\n",
    "word2idx[u'U'] = char_num\n",
    "idx2word[char_num+1] = u'P'\n",
    "word2idx[u'P'] = char_num+1\n",
    "\n",
    "init_weight_wv.append(np.random.randn(100,))\n",
    "init_weight_wv.append(np.zeros(100,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Read the data and convert it to the four-tag format: S B M E\n",
    "output_file = 'icwb2-data/training/msr_training.tagging.utf8'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import codecs\n",
    "import sys\n",
    "\n",
    "# Tag each character of the segmented corpus with B/M/E/S\n",
    "def character_tagging(input_file, output_file):\n",
    "    input_data = codecs.open(input_file, 'r', 'utf-8')\n",
    "    output_data = codecs.open(output_file, 'w', 'utf-8')\n",
    "    for line in input_data.readlines():\n",
    "        word_list = line.strip().split()\n",
    "        for word in word_list:\n",
    "            if len(word) == 1:\n",
    "                output_data.write(word + \"/S \")\n",
    "            else:\n",
    "                output_data.write(word[0] + \"/B \")\n",
    "                for w in word[1:len(word)-1]:\n",
    "                    output_data.write(w + \"/M \")\n",
    "                output_data.write(word[len(word)-1] + \"/E \")\n",
    "        output_data.write(\"\\n\")\n",
    "    input_data.close()\n",
    "    output_data.close()\n",
    "\n",
    "character_tagging(input_file, output_file)"
   ]
  },
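  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the B/M/E/S scheme concrete, here is a minimal sketch on a hypothetical in-memory sentence (not part of the original corpus run): single-character words get S, and multi-character words get B, then M for interior characters, then E."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Minimal sketch: tag one pre-segmented line with B/M/E/S (hypothetical example)\n",
    "def tag_line(line):\n",
    "    tagged = []\n",
    "    for word in line.split():\n",
    "        if len(word) == 1:\n",
    "            tagged.append(word + u'/S')\n",
    "        else:\n",
    "            tagged.append(word[0] + u'/B')\n",
    "            tagged.extend(w + u'/M' for w in word[1:-1])\n",
    "            tagged.append(word[-1] + u'/E')\n",
    "    return u' '.join(tagged)\n",
    "\n",
    "print(tag_line(u'今天 天气 很 好'))  # -> 今/B 天/E 天/B 气/E 很/S 好/S"
   ]
  },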
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Separate characters and labels\n",
    "with open(output_file) as f:\n",
    "    lines = f.readlines()\n",
    "    train_line = [[w[0] for w in line.decode('utf-8').split()] for line in lines]\n",
    "    train_label = [w[2] for line in lines for w in line.decode('utf-8').split()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Convert a sentence (list of characters) into a list of fixed-width index windows\n",
    "import numpy as np\n",
    "def sent2num(sentence, word2idx=word2idx, context=7):\n",
    "    predict_word_num = []\n",
    "    for w in sentence:\n",
    "        # Map each character to its index; unseen characters map to 'U'\n",
    "        if w in word2idx:\n",
    "            predict_word_num.append(word2idx[w])\n",
    "        else:\n",
    "            predict_word_num.append(word2idx[u'U'])\n",
    "    # Pad both ends with 'P'\n",
    "    num = len(predict_word_num)\n",
    "    pad = int((context-1)*0.5)\n",
    "    for i in range(pad):\n",
    "        predict_word_num.insert(0, word2idx[u'P'])\n",
    "        predict_word_num.append(word2idx[u'P'])\n",
    "    train_x = []\n",
    "    for i in range(num):\n",
    "        train_x.append(predict_word_num[i:i+context])\n",
    "    return train_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[88120, 88120, 88120, 9, 21, 107, 1976],\n",
       " [88120, 88120, 9, 21, 107, 1976, 26],\n",
       " [88120, 9, 21, 107, 1976, 26, 1116],\n",
       " [9, 21, 107, 1976, 26, 1116, 1397],\n",
       " [21, 107, 1976, 26, 1116, 1397, 7],\n",
       " [107, 1976, 26, 1116, 1397, 7, 10],\n",
       " [1976, 26, 1116, 1397, 7, 10, 538],\n",
       " [26, 1116, 1397, 7, 10, 538, 2300],\n",
       " [1116, 1397, 7, 10, 538, 2300, 2353],\n",
       " [1397, 7, 10, 538, 2300, 2353, 378],\n",
       " [7, 10, 538, 2300, 2353, 378, 0],\n",
       " [10, 538, 2300, 2353, 378, 0, 46],\n",
       " [538, 2300, 2353, 378, 0, 46, 2118],\n",
       " [2300, 2353, 378, 0, 46, 2118, 18],\n",
       " [2353, 378, 0, 46, 2118, 18, 2610],\n",
       " [378, 0, 46, 2118, 18, 2610, 1],\n",
       " [0, 46, 2118, 18, 2610, 1, 1172],\n",
       " [46, 2118, 18, 2610, 1, 1172, 2183],\n",
       " [2118, 18, 2610, 1, 1172, 2183, 78],\n",
       " [18, 2610, 1, 1172, 2183, 78, 7],\n",
       " [2610, 1, 1172, 2183, 78, 7, 16],\n",
       " [1, 1172, 2183, 78, 7, 16, 141],\n",
       " [1172, 2183, 78, 7, 16, 141, 70],\n",
       " [2183, 78, 7, 16, 141, 70, 110],\n",
       " [78, 7, 16, 141, 70, 110, 1],\n",
       " [7, 16, 141, 70, 110, 1, 2300],\n",
       " [16, 141, 70, 110, 1, 2300, 2353],\n",
       " [141, 70, 110, 1, 2300, 2353, 378],\n",
       " [70, 110, 1, 2300, 2353, 378, 0],\n",
       " [110, 1, 2300, 2353, 378, 0, 87],\n",
       " [1, 2300, 2353, 378, 0, 87, 3986],\n",
       " [2300, 2353, 378, 0, 87, 3986, 1962],\n",
       " [2353, 378, 0, 87, 3986, 1962, 7],\n",
       " [378, 0, 87, 3986, 1962, 7, 462],\n",
       " [0, 87, 3986, 1962, 7, 462, 180],\n",
       " [87, 3986, 1962, 7, 462, 180, 91],\n",
       " [3986, 1962, 7, 462, 180, 91, 1962],\n",
       " [1962, 7, 462, 180, 91, 1962, 1],\n",
       " [7, 462, 180, 91, 1962, 1, 1384],\n",
       " [462, 180, 91, 1962, 1, 1384, 37],\n",
       " [180, 91, 1962, 1, 1384, 37, 1],\n",
       " [91, 1962, 1, 1384, 37, 1, 43],\n",
       " [1962, 1, 1384, 37, 1, 43, 463],\n",
       " [1, 1384, 37, 1, 43, 463, 1380],\n",
       " [1384, 37, 1, 43, 463, 1380, 2],\n",
       " [37, 1, 43, 463, 1380, 2, 88120],\n",
       " [1, 43, 463, 1380, 2, 88120, 88120],\n",
       " [43, 463, 1380, 2, 88120, 88120, 88120]]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Input: a list of characters; output: a list of index windows\n",
    "sent2num(train_line[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Convert all training text into index windows\n",
    "train_word_num = []\n",
    "for line in train_line:\n",
    "    train_word_num.extend(sent2num(line))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4050469\n",
      "4050469\n"
     ]
    }
   ],
   "source": [
    "print len(train_word_num)\n",
    "print len(train_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from cPickle import load\n",
    "#dump(train_word_num, open('train_word_num.pickle', 'wb'))\n",
    "#train_word_num = load(open('train_word_num.pickle','rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "nb_classes = len(np.unique(train_label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Step 2: build and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import print_function\n",
    "\n",
    "from keras.preprocessing import sequence\n",
    "from keras.optimizers import SGD, RMSprop, Adagrad\n",
    "from keras.utils import np_utils\n",
    "from keras.models import Sequential, Graph\n",
    "from keras.layers.core import Dense, Dropout, Activation, TimeDistributedDense\n",
    "from keras.layers.embeddings import Embedding\n",
    "from keras.layers.recurrent import LSTM, GRU, SimpleRNN\n",
    "from keras.layers.core import Reshape, Flatten\n",
    "from keras.regularizers import l1, l2\n",
    "from keras.layers.convolutional import Convolution2D, MaxPooling2D, MaxPooling1D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{u'M': 2, u'S': 3, u'B': 0, u'E': 1}\n",
      "{0: u'B', 1: u'E', 2: u'M', 3: u'S'}\n"
     ]
    }
   ],
   "source": [
    "# Build the two label dictionaries\n",
    "label_dict = dict(zip(np.unique(train_label), range(4)))\n",
    "num_dict = {n: l for l, n in label_dict.iteritems()}\n",
    "print(label_dict)\n",
    "print(num_dict)\n",
    "# Convert the target labels to integers\n",
    "train_label = [label_dict[y] for y in train_label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Split the dataset\n",
    "from sklearn.cross_validation import train_test_split\n",
    "train_word_num = np.array(train_word_num)\n",
    "train_X, test_X, train_y, test_y = train_test_split(train_word_num, train_label, train_size=0.9, random_state=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "Y_train = np_utils.to_categorical(train_y, nb_classes)\n",
    "Y_test = np_utils.to_categorical(test_y, nb_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3645422 train sequences\n",
      "405047 test sequences\n"
     ]
    }
   ],
   "source": [
    "print(len(train_X), 'train sequences')\n",
    "print(len(test_X), 'test sequences')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Prepare the initial character-vector weights for the Embedding layer\n",
    "init_weight = [np.array(init_weight_wv)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "batch_size = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "maxfeatures = init_weight[0].shape[0]  # vocabulary size\n",
    "word_dim = 100\n",
    "maxlen = 7\n",
    "hidden_units = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# stacking LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stacking LSTM...\n"
     ]
    }
   ],
   "source": [
    "print('stacking LSTM...')\n",
    "model = Sequential()\n",
    "model.add(Embedding(maxfeatures, word_dim, input_length=maxlen))\n",
    "model.add(LSTM(output_dim=hidden_units, return_sequences=True))\n",
    "model.add(LSTM(output_dim=hidden_units, return_sequences=False))\n",
    "model.add(Dropout(0.5))\n",
    "model.add(Dense(nb_classes))\n",
    "model.add(Activation('softmax'))\n",
    "model.compile(loss='categorical_crossentropy', optimizer='adam')"
   ]
  },
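  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: the `Embedding` layer above is initialized randomly, so the word2vec lookup matrix prepared in `init_weight` is not actually used by this model. Older Keras layers accept an initial `weights` list in the constructor; the commented line below is a sketch of how the pretrained vectors could be plugged in (an assumption about that API, not something run in this notebook)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sketch only: seed the embedding with the word2vec vectors via `weights`\n",
    "# (assumes init_weight[0] has shape (maxfeatures, word_dim), as prepared earlier)\n",
    "# model.add(Embedding(maxfeatures, word_dim, input_length=maxlen, weights=init_weight))"
   ]
  },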
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train...\n",
      "Train on 3645422 samples, validate on 405047 samples\n",
      "Epoch 1/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.2642 - acc: 0.9054 - val_loss: 0.1585 - val_acc: 0.9446\n",
      "Epoch 2/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.1341 - acc: 0.9549 - val_loss: 0.1166 - val_acc: 0.9603\n",
      "Epoch 3/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0993 - acc: 0.9671 - val_loss: 0.0976 - val_acc: 0.9673\n",
      "Epoch 4/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0813 - acc: 0.9733 - val_loss: 0.0870 - val_acc: 0.9711\n",
      "Epoch 5/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0703 - acc: 0.9769 - val_loss: 0.0818 - val_acc: 0.9728\n",
      "Epoch 6/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0630 - acc: 0.9792 - val_loss: 0.0760 - val_acc: 0.9749\n",
      "Epoch 7/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0579 - acc: 0.9810 - val_loss: 0.0723 - val_acc: 0.9764\n",
      "Epoch 8/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0541 - acc: 0.9821 - val_loss: 0.0698 - val_acc: 0.9772\n",
      "Epoch 9/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0512 - acc: 0.9832 - val_loss: 0.0697 - val_acc: 0.9774\n",
      "Epoch 10/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0493 - acc: 0.9838 - val_loss: 0.0678 - val_acc: 0.9781\n",
      "Epoch 11/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0474 - acc: 0.9844 - val_loss: 0.0664 - val_acc: 0.9788\n",
      "Epoch 12/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0458 - acc: 0.9849 - val_loss: 0.0659 - val_acc: 0.9787\n",
      "Epoch 13/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0446 - acc: 0.9853 - val_loss: 0.0647 - val_acc: 0.9797\n",
      "Epoch 14/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0435 - acc: 0.9857 - val_loss: 0.0627 - val_acc: 0.9798\n",
      "Epoch 15/20\n",
      "3645422/3645422 [==============================] - 648s - loss: 0.0426 - acc: 0.9860 - val_loss: 0.0625 - val_acc: 0.9799\n",
      "Epoch 16/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0420 - acc: 0.9862 - val_loss: 0.0615 - val_acc: 0.9804\n",
      "Epoch 17/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0411 - acc: 0.9864 - val_loss: 0.0601 - val_acc: 0.9804\n",
      "Epoch 18/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0406 - acc: 0.9866 - val_loss: 0.0614 - val_acc: 0.9810\n",
      "Epoch 19/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0401 - acc: 0.9867 - val_loss: 0.0624 - val_acc: 0.9804\n",
      "Epoch 20/20\n",
      "3645422/3645422 [==============================] - 649s - loss: 0.0398 - acc: 0.9869 - val_loss: 0.0618 - val_acc: 0.9807\n"
     ]
    }
   ],
   "source": [
    "# train_X, test_X, Y_train, Y_test\n",
    "print(\"Train...\")\n",
    "result = model.fit(train_X, Y_train, batch_size=batch_size,\n",
    "                   nb_epoch=20, validation_data=(test_X, Y_test), show_accuracy=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# bidirectional LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "graph = Graph()\n",
    "graph.add_input(name='input', input_shape=(maxlen,), dtype=int)\n",
    "graph.add_node(Embedding(maxfeatures, word_dim, input_length=maxlen),\n",
    "               name='embedding', input='input')\n",
    "graph.add_node(LSTM(output_dim=hidden_units), name='forward', input='embedding')\n",
    "graph.add_node(LSTM(output_dim=hidden_units, go_backwards=True), name='backward', input='embedding')\n",
    "graph.add_node(Dropout(0.5), name='dropout', inputs=['forward', 'backward'])\n",
    "graph.add_node(Dense(nb_classes, activation='softmax'), name='softmax', input='dropout')\n",
    "graph.add_output(name='output', input='softmax')\n",
    "graph.compile(loss={'output': 'categorical_crossentropy'}, optimizer='adam')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 3645422 samples, validate on 405047 samples\n",
      "Epoch 1/20\n",
      "3645422/3645422 [==============================] - 703s - loss: 0.2795 - val_loss: 0.1710\n",
      "Epoch 2/20\n",
      "3645422/3645422 [==============================] - 701s - loss: 0.1448 - val_loss: 0.1225\n",
      "Epoch 3/20\n",
      "3645422/3645422 [==============================] - 702s - loss: 0.1090 - val_loss: 0.1029\n",
      "Epoch 4/20\n",
      "3645422/3645422 [==============================] - 701s - loss: 0.0906 - val_loss: 0.0930\n",
      "Epoch 5/20\n",
      "3645422/3645422 [==============================] - 701s - loss: 0.0794 - val_loss: 0.0876\n",
      "Epoch 6/20\n",
      "3645422/3645422 [==============================] - 701s - loss: 0.0727 - val_loss: 0.0841\n",
      "Epoch 7/20\n",
      "3645422/3645422 [==============================] - 702s - loss: 0.0677 - val_loss: 0.0809\n",
      "Epoch 8/20\n",
      "3645422/3645422 [==============================] - 702s - loss: 0.0640 - val_loss: 0.0789\n",
      "Epoch 9/20\n",
      "3645422/3645422 [==============================] - 702s - loss: 0.0613 - val_loss: 0.0762\n",
      "Epoch 10/20\n",
      "3645422/3645422 [==============================] - 701s - loss: 0.0592 - val_loss: 0.0760\n",
      "Epoch 11/20\n",
      "3645422/3645422 [==============================] - 702s - loss: 0.0574 - val_loss: 0.0761\n",
      "Epoch 12/20\n",
      "3645422/3645422 [==============================] - 701s - loss: 0.0563 - val_loss: 0.0727\n",
      "Epoch 13/20\n",
      "3645422/3645422 [==============================] - 702s - loss: 0.0552 - val_loss: 0.0732\n",
      "Epoch 14/20\n",
      "3645422/3645422 [==============================] - 702s - loss: 0.0542 - val_loss: 0.0717\n",
      "Epoch 15/20\n",
      "3645422/3645422 [==============================] - 701s - loss: 0.0533 - val_loss: 0.0735\n",
      "Epoch 16/20\n",
      "3645422/3645422 [==============================] - 700s - loss: 0.0525 - val_loss: 0.0707\n",
      "Epoch 17/20\n",
      "3645422/3645422 [==============================] - 702s - loss: 0.0520 - val_loss: 0.0716\n",
      "Epoch 18/20\n",
      "3645422/3645422 [==============================] - 701s - loss: 0.0517 - val_loss: 0.0711\n",
      "Epoch 19/20\n",
      "3645422/3645422 [==============================] - 701s - loss: 0.0510 - val_loss: 0.0700\n",
      "Epoch 20/20\n",
      "3645422/3645422 [==============================] - 700s - loss: 0.0506 - val_loss: 0.0715\n"
     ]
    }
   ],
   "source": [
    "result2 = graph.fit({'input': train_X, 'output': Y_train},\n",
    "                    batch_size=batch_size,\n",
    "                    nb_epoch=20, validation_data={'input': test_X, 'output': Y_test})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(88121, 100)\n",
      "(100, 100)\n",
      "(100, 100)\n",
      "(100,)\n",
      "(100, 100)\n",
      "(100, 100)\n",
      "(100,)\n",
      "(100, 100)\n",
      "(100, 100)\n",
      "(100,)\n",
      "(100, 100)\n",
      "(100, 100)\n",
      "(100,)\n",
      "(100, 100)\n",
      "(100, 100)\n",
      "(100,)\n",
      "(100, 100)\n",
      "(100, 100)\n",
      "(100,)\n",
      "(100, 100)\n",
      "(100, 100)\n",
      "(100,)\n",
      "(100, 100)\n",
      "(100, 100)\n",
      "(100,)\n",
      "(100, 4)\n",
      "(4,)\n"
     ]
    }
   ],
   "source": [
    "import theano\n",
    "# The model's learnable parameters, i.e. the shapes of the weight matrices\n",
    "weight_len = len(model.get_weights())\n",
    "for i in range(weight_len):\n",
    "    print(model.get_weights()[i].shape)\n",
    "# Reading down the list: the embedding, then W, U, b for each of the four gates\n",
    "# of each LSTM layer, and finally dense_w and dense_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 4)"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer = theano.function([model.layers[0].input], model.layers[3].get_output(train=False), allow_input_downcast=True)\n",
    "layer_out = layer(test_X[:10])\n",
    "layer_out.shape  # output at model.layers[3] for the first 10 test samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Steps 3 and 4: tag the test text with the model and evaluate the result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "temp_txt = u'国家食药监总局发布通知称,酮康唑口服制剂因存在严重肝毒性不良反应,即日起停止生产销售使用。'\n",
    "temp_txt = list(temp_txt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[88120, 88120, 88120, 309, 223, 5082, 1522],\n",
       " [88120, 88120, 309, 223, 5082, 1522, 6778],\n",
       " [88120, 309, 223, 5082, 1522, 6778, 396],\n",
       " [309, 223, 5082, 1522, 6778, 396, 1773],\n",
       " [223, 5082, 1522, 6778, 396, 1773, 1053]]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp_num = sent2num(temp_txt)\n",
    "temp_num[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Infer the tag sequence from the input, enforcing B/M/E/S transition constraints\n",
    "def predict_num(input_num, input_txt, \\\n",
    "                model,\\\n",
    "                label_dict=label_dict,\\\n",
    "                num_dict=num_dict):\n",
    "    input_num = np.array(input_num)\n",
    "    predict_prob = model.predict_proba(input_num, verbose=False)\n",
    "    predict_label = model.predict_classes(input_num, verbose=False)\n",
    "    for i, label in enumerate(predict_label[:-1]):\n",
    "        # the first character cannot be E or M\n",
    "        if i == 0:\n",
    "            predict_prob[i, label_dict[u'E']] = 0\n",
    "            predict_prob[i, label_dict[u'M']] = 0\n",
    "        # after B, the next tag cannot be B or S\n",
    "        if label == label_dict[u'B']:\n",
    "            predict_prob[i+1, label_dict[u'B']] = 0\n",
    "            predict_prob[i+1, label_dict[u'S']] = 0\n",
    "        # after E, the next tag cannot be M or E\n",
    "        if label == label_dict[u'E']:\n",
    "            predict_prob[i+1, label_dict[u'M']] = 0\n",
    "            predict_prob[i+1, label_dict[u'E']] = 0\n",
    "        # after M, the next tag cannot be B or S\n",
    "        if label == label_dict[u'M']:\n",
    "            predict_prob[i+1, label_dict[u'B']] = 0\n",
    "            predict_prob[i+1, label_dict[u'S']] = 0\n",
    "        # after S, the next tag cannot be M or E\n",
    "        if label == label_dict[u'S']:\n",
    "            predict_prob[i+1, label_dict[u'M']] = 0\n",
    "            predict_prob[i+1, label_dict[u'E']] = 0\n",
    "        predict_label[i+1] = predict_prob[i+1].argmax()\n",
    "    predict_label_new = [num_dict[x] for x in predict_label]\n",
    "    result = [w + '/' + l for w, l in zip(input_txt, predict_label_new)]\n",
    "    return ' '.join(result) + '\\n'"
   ]
  },
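  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function above does greedy left-to-right decoding: at each position it zeroes out the probabilities of tags that cannot legally follow the previous tag, then re-takes the argmax. The small sketch below (not from the original notebook) just summarizes the transition table those rules enforce."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Allowed B/M/E/S transitions implied by the rules in predict_num:\n",
    "# a word runs B (M ...) E, or is a single S; the first tag must be B or S.\n",
    "allowed_next = {\n",
    "    u'B': [u'M', u'E'],  # after B: continue or end the word\n",
    "    u'M': [u'M', u'E'],  # after M: continue or end the word\n",
    "    u'E': [u'B', u'S'],  # after E: start a new word\n",
    "    u'S': [u'B', u'S'],  # after S: start a new word\n",
    "}"
   ]
  },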
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "国/B 家/M 食/M 药/M 监/M 总/M 局/E 发/B 布/E 通/B 知/E 称/S ,/S 酮/B 康/E 唑/B 口/E 服/S 制/B 剂/E 因/S 存/B 在/E 严/B 重/E 肝/S 毒/B 性/E 不/B 良/E 反/B 应/E ,/S 即/B 日/E 起/S 停/B 止/E 生/B 产/E 销/B 售/E 使/B 用/E 。/S\n",
      "\n"
     ]
    }
   ],
   "source": [
    "temp = predict_num(temp_num, temp_txt, model=model)\n",
    "print(temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "test_file = 'icwb2-data/testing/msr_test.utf8'\n",
    "with open(test_file, 'r') as f:\n",
    "    lines = f.readlines()\n",
    "    test_texts = [list(line.decode('utf-8').strip()) for line in lines]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Tag every test sentence with the trained model\n",
    "test_output = []\n",
    "for line in test_texts:\n",
    "    test_num = sent2num(line)\n",
    "    output_line = predict_num(test_num, input_txt=line, model=model)\n",
    "    test_output.append(output_line.encode('utf-8'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "with open('icwb2-data/testing/msr_test_output.utf8', 'w') as f:\n",
    "    f.writelines(test_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "input_file = 'icwb2-data/testing/msr_test_output.utf8'\n",
    "output_file = 'icwb2-data/testing/msr_test.split.tag2word.utf8'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import codecs\n",
    "import sys\n",
    "\n",
    "# 4 tags for character tagging: B(Begin), E(End), M(Middle), S(Single)\n",
    "def character_2_word(input_file, output_file):\n",
    "    input_data = codecs.open(input_file, 'r', 'utf-8')\n",
    "    output_data = codecs.open(output_file, 'w', 'utf-8')\n",
    "    for line in input_data.readlines():\n",
    "        char_tag_list = line.strip().split()\n",
    "        for char_tag in char_tag_list:\n",
    "            char_tag_pair = char_tag.split('/')\n",
    "            char = char_tag_pair[0]\n",
    "            tag = char_tag_pair[1]\n",
    "            if tag == 'B':\n",
    "                output_data.write(' ' + char)\n",
    "            elif tag == 'M':\n",
    "                output_data.write(char)\n",
    "            elif tag == 'E':\n",
    "                output_data.write(char + ' ')\n",
    "            else:  # tag == 'S'\n",
    "                output_data.write(' ' + char + ' ')\n",
    "        output_data.write(\"\\n\")\n",
    "    input_data.close()\n",
    "    output_data.close()\n",
    "\n",
    "character_2_word(input_file, output_file)"
   ]
  },
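  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the same tag-to-word conversion can be run on a single in-memory string (a hypothetical fragment taken from the demo output above, not part of the scored run):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Sketch: convert one '/tag'-annotated string back to space-separated words\n",
    "def tags_to_words(tagged_line):\n",
    "    words, buf = [], u''\n",
    "    for char_tag in tagged_line.split():\n",
    "        char, tag = char_tag.split(u'/')\n",
    "        if tag in (u'B', u'M'):\n",
    "            buf += char  # word continues\n",
    "        elif tag == u'E':\n",
    "            words.append(buf + char)  # word ends\n",
    "            buf = u''\n",
    "        else:  # 'S': single-character word\n",
    "            words.append(char)\n",
    "    return u' '.join(words)\n",
    "\n",
    "print(tags_to_words(u'发/B 布/E 通/B 知/E 称/S'))  # -> 发布 通知 称"
   ]
  },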
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### - Final F-score from the perl scoring script: 0.949"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "! ./icwb2-data/scripts/score ./icwb2-data/gold/msr_training_words.utf8 ./icwb2-data/gold/msr_test_gold.utf8 ./icwb2-data/testing/msr_test.split.tag2word.utf8 > deep.score"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}