Last active
October 9, 2015 10:49
-
-
Save Salinger/9add0c5e944e46cbf1b9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
#-*- coding:utf-8 -*-
### Library
import os

import MySQLdb
import pandas.io.sql as psql
import pandas as pd
import numpy as np
import MeCab
from gensim import corpora
from gensim import matutils
from sklearn import svm
from sklearn import cross_validation
### Fetch the target documents
# Load every (text, label) row of the test set into the DataFrame `docs`.
# Connection settings can be overridden through environment variables so that
# real credentials need not live in source control; the defaults are the
# values this script previously hard-coded, so behaviour is unchanged when
# the variables are unset.
# NOTE(review): MySQL's "utf8" charset is 3-byte and cannot carry 4-byte
# characters such as emoji; consider "utf8mb4" if the table stores them.
print("Fetch documents")
connection = MySQLdb.connect(
    host=os.environ.get("DOC_DB_HOST", "XXXX.rds.amazonaws.com"),
    user=os.environ.get("DOC_DB_USER", "user"),
    passwd=os.environ.get("DOC_DB_PASSWORD", "pass"),
    db=os.environ.get("DOC_DB_NAME", "textdata"),
    charset='utf8',
)
try:
    docs = psql.read_sql(
        "SELECT text, label FROM doc_pn_testset ",
        connection,
    )
finally:
    # The connection used to be created inline and never closed; release it
    # as soon as the rows have been read, even if the query fails.
    connection.close()
### Split text into morphemes
print("Wakati")

# Shared MeCab tagger used by wakati(). It was referenced but never created,
# so every call to wakati() failed with a NameError.
tagger = MeCab.Tagger()


def wakati(text):
    """Split *text* into a list of lemmas with MeCab.

    Each morpheme contributes its base form (feature field index 6). When that
    field is "*" or absent, the surface form is used instead.

    :param text: unicode string to analyse.
    :returns: list of unicode tokens, excluding MeCab's leading and trailing
        sentinel nodes.
    """
    # MeCab is fed UTF-8 bytes, and its byte-string results are decoded below.
    # The encoded buffer stays bound to a local name while the nodes are walked.
    # NOTE(review): nodes appear to reference the input buffer in some
    # bindings; confirm before inlining this.
    encoded = text.encode("utf-8")
    node = tagger.parseToNode(encoded)
    word_list = []
    while node:
        fields = node.feature.split(",")
        # Guard against feature strings with fewer than 7 fields (this depends
        # on the installed dictionary): treat a missing base form like "*"
        # instead of raising IndexError.
        lemma = fields[6] if len(fields) > 6 else "*"
        if lemma == "*":
            lemma = node.surface
        word_list.append(lemma.decode("utf-8"))
        node = node.next
    # The first and last nodes are the BOS/EOS sentinels, not real words.
    return word_list[1:-1]


# Tokenise every document. The BOW step reads this column, which was
# previously never created (KeyError on 'spl_text').
docs['spl_text'] = docs['text'].apply(wakati)
### Convert to BOW (term-frequency) feature vectors
print("BOW")


def convert_feature_vector(word_list):
    """Return the dense term-frequency vector for *word_list*.

    Tokens are mapped through the module-level gensim dictionary
    ``word_dict``. Per the gensim docs, ``doc2bow`` ignores tokens the
    dictionary does not contain. The result is a plain list with one entry
    per dictionary term.
    """
    bow = word_dict.doc2bow(word_list)
    # corpus2dense lays the one-document corpus out with documents as
    # columns, so column 0 is this document's vector.
    dense = matutils.corpus2dense([bow], num_terms=len(word_dict))
    return list(dense[:, 0])


word_dict = corpora.Dictionary(docs['spl_text'])
docs['feature_vector'] = docs['spl_text'].apply(convert_feature_vector)
### Train the SVM
print("SVM: learn")
# Reference: 20-fold cross-validation sweep over C = 2**-5 .. 2**14
# (Python 2 / old scikit-learn API, not executed):
# for C in [2 ** n for n in range(-5, 15)]:
#     model = svm.LinearSVC(C=C)
#     scores = cross_validation.cross_val_score(
#         model,
#         list(docs['feature_vector']),
#         list(docs['label']),
#         cv=20,
#         n_jobs=1,
#     )
#     print C, pd.Series(scores).mean()
train_features = list(docs['feature_vector'])
train_labels = list(docs['label'])
# C = 0.03125 = 2**-5, presumably picked with the sweep above.
# NOTE(review): newer scikit-learn replaces class_weight="auto" with
# "balanced" (not numerically identical); revisit when upgrading.
model = svm.LinearSVC(class_weight="auto", C=0.03125)
model.fit(train_features, train_labels)
### Classify with the SVM
print("SVM: predict")
sample_sentences = [
    u"無免許運転をネット中継 逮捕 - Y!ニュース news.yahoo.co.jp/pickup/6152325 頭悪すぎ",
    u"松谷みよ子さん亡くなったの……。",
    u"市来さん結婚でもしかして!と思って某声優の名前をTwitterで検索したらやっぱり…",
    u"タカ丸さんかわいいな~",
]
sample_text = pd.Series(sample_sentences)
# Apply the same preprocessing as the training data: tokenise, then
# vectorise against the training dictionary.
split_sample_text = sample_text.apply(wakati)
feature_vectors = split_sample_text.apply(convert_feature_vector)
print(model.predict(list(feature_vectors)))
# Output recorded in the original script:
# > [0, 0, 1, 1]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment