import numpy as np

from splearn.rdd import ArrayRDD
from splearn.rdd import DictRDD
from splearn.feature_extraction.text import SparkCountVectorizer
from splearn.feature_extraction.text import SparkHashingVectorizer
from splearn.feature_extraction.text import SparkTfidfTransformer
from splearn.naive_bayes import SparkMultinomialNB
from splearn.naive_bayes import SparkGaussianNB
from splearn.svm import SparkLinearSVC
from splearn.linear_model import SparkSGDClassifier
from splearn.linear_model import SparkLogisticRegression
from splearn.pipeline import SparkPipeline
from splearn.grid_search import SparkGridSearchCV

import MeCab
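# The splearn modules come from the sparkit-learn package; `sc` used below is
# assumed to be an already-running SparkContext, e.g. the one pyspark provides.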
# POS classes to drop: conjunctions, particles, auxiliary verbs and symbols.
default_stop_pos = ['接続詞', '助詞', '助動詞', '記号']

def mecab_analyzer(text, stop_pos=default_stop_pos):
    """Tokenize Japanese text with MeCab, skipping stop POS and preferring base forms."""
    mecab = MeCab.Tagger('-Ochasen')
    encoded_text = text.encode('utf-8')
    node = mecab.parseToNode(encoded_text)
    node = node.next  # skip the BOS node
    word = []
    while node:
        surface = node.surface
        feature_array = node.feature.split(',')
        if feature_array[0] == 'BOS/EOS' or feature_array[0] in stop_pos:
            node = node.next
            continue
        # use the base form (7th feature) when available, otherwise the surface form
        if feature_array[6] == '*':
            w = surface
        else:
            w = feature_array[6]
        word.append(w.decode('utf-8'))
        node = node.next
    return word
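# Illustrative only: for a sentence like u'今日は良い天気です' the analyzer should
# yield base forms such as [u'今日', u'良い', u'天気'] — the particles and the
# auxiliary verb are dropped. Exact tokens depend on the installed MeCab dictionary.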
# Load positive (label 1) and negative (label 0) documents as (path, text) pairs.
pos_text = sc.wholeTextFiles("hdfs://hdp1.containers.dev:9000/user/root/data/binary_clf/all/1")
neg_text = sc.wholeTextFiles("hdfs://hdp1.containers.dev:9000/user/root/data/binary_clf/all/0")

# Build (text, label) pairs and shuffle them by sorting on a random key.
xy = (pos_text.map(lambda x: (x[1], 1))
      .union(neg_text.map(lambda x: (x[1], 0)))
      .map(lambda x: (x, np.random.rand()))   # attach a random sort key
      .sortBy(lambda x: x[1])                 # shuffle by sorting on that key
      .map(lambda x: x[0]))                   # drop the key, keep (text, label)
fold = 3

# Split the shuffled (text, label) list into `fold` roughly equal chunks.
data = []
as_list = xy.collect()
size = len(as_list) // fold
for i in range(fold - 1):
    data.append(as_list[size * i:size * (i + 1)])
data.append(as_list[size * (fold - 1):])  # last chunk absorbs any remainder
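# For example, with 300 documents and fold = 3, `data` holds three lists of
# 100 (text, label) pairs each; with 301 documents the last chunk gets 101.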
# Cross validation: on iteration i, fold (i + fold - 1) % fold is held out as
# the test set and the remaining folds are used for training.
cv_accuracy = []
cv_precision = []
cv_recall = []

for i in range(fold):
    print 'iteration', i
    train = []
    for k in range(fold):
        data_idx = (k + i) % fold
        if k < fold - 1:
            print data_idx, 'train data',
            train.extend(data[data_idx])
        else:
            print data_idx, 'test data'
            test = data[data_idx]
    # Training data: wrap the texts and labels as ArrayRDDs and zip them into
    # a DictRDD with an 'X' (documents) and a 'y' (labels) column.
    train_x = ArrayRDD(sc.parallelize(train).map(lambda x: x[0]))
    train_y = ArrayRDD(sc.parallelize(train).map(lambda x: x[1]))
    Z = DictRDD((train_x, train_y), columns=('X', 'y'), dtype=[np.ndarray, np.ndarray])
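    # DictRDD is splearn's column-wise distributed container; a single column
    # can be selected later with e.g. Z[:, 'X'], as done for prediction below.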
    # Distributed pipeline: vectorize with the MeCab analyzer, apply IDF weighting,
    # then classify. Alternative vectorizers/classifiers are left commented out.
    dist_pipeline = SparkPipeline((
        # ('vect', SparkCountVectorizer(analyzer=mecab_analyzer)),    # count TF
        # ('vect', SparkHashingVectorizer(analyzer=mecab_analyzer)),  # hashing TF
        ('vect', SparkHashingVectorizer(analyzer=mecab_analyzer, non_negative=True)),  # hashing TF; non_negative for NB
        ('tfidf', SparkTfidfTransformer()),  # IDF
        # ('clf', SparkLinearSVC(C=1.0))     # linear SVM
        ('clf', SparkMultinomialNB(alpha=0.05))  # multinomial NB
    ))
    # fit
    dist_pipeline.fit(Z, clf__classes=np.array([0, 1]))
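    # Note (based on how splearn fits the naive Bayes block-by-block over RDD
    # partitions): the full label set is supplied up front via clf__classes so
    # that every partition-level fit sees both classes, even if a given block
    # happens to contain only one of them.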
    # Test data, wrapped the same way as the training set.
    test_x = ArrayRDD(sc.parallelize(test).map(lambda x: x[0]))
    test_y = ArrayRDD(sc.parallelize(test).map(lambda x: x[1]))
    test_Z = DictRDD((test_x, test_y), columns=('X', 'y'), dtype=[np.ndarray, np.ndarray])

    # Predict labels for the held-out fold (only the 'X' column is fed in).
    predicts = dist_pipeline.predict(test_Z[:, 'X'])
    # Metrics (accuracy, precision, recall, f1), computed locally on the driver.
    data_size = len(test)
    array_y = test_y.toarray()
    array_pred = predicts.toarray()
    y_and_pred = zip(array_y, array_pred)
    pos_size = sum(array_y)                    # actual positives
    neg_size = data_size - pos_size            # actual negatives
    pos_pred_size = sum(array_pred)            # predicted positives
    neg_pred_size = data_size - pos_pred_size  # predicted negatives
    pos_acc_size = len(filter(lambda x: x[0] == 1 and x[0] == x[1], y_and_pred))  # true positives
    neg_acc_size = len(filter(lambda x: x[0] == 0 and x[0] == x[1], y_and_pred))  # true negatives
    acc_size = pos_acc_size + neg_acc_size
    accuracy = acc_size / float(data_size)
    precision = pos_acc_size / float(pos_pred_size)
    recall = pos_acc_size / float(pos_size)
    f1 = 2 * (precision * recall) / (precision + recall)
    score = {'accuracy': accuracy, 'recall': recall, 'precision': precision, 'f1': f1}
    print score
    cv_accuracy.append(score['accuracy'])
    cv_precision.append(score['precision'])
    cv_recall.append(score['recall'])
# Cross-validation summary: average the per-fold scores. Note that the reported
# f1 is computed from the averaged precision/recall, not averaged per fold.
accuracy = sum(cv_accuracy) / len(cv_accuracy)
precision = sum(cv_precision) / len(cv_precision)
recall = sum(cv_recall) / len(cv_recall)
cv_score = {'accuracy': accuracy, 'precision': precision, 'recall': recall,
            'f1': 2 * (precision * recall) / (precision + recall)}
print cv_score
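# Note: mecab_analyzer runs inside Spark tasks, so MeCab (and its dictionary)
# must be installed on every worker node, not only on the driver.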