.gitignore:
*.pyc
*.diff
The main script. It trains a Naive Bayes classifier on tweets from the configured accounts, then classifies sentences typed by the user:
# -*- coding: utf-8 -*-
import sys

from naivebayes import NaiveBayes
from twitter_reader import get_tweets


class Classifier(object):
    def __init__(self):
        self.classifier = NaiveBayes()

    def learn_from_tweets(self, user_ids, category):
        tweets = get_tweets(user_ids)
        categories = [category] * len(tweets)
        print("Training...")
        self.classifier.fit(tweets, categories)

    def predict_user_input(self):
        """Read user input until 'exit' is entered"""
        sentence = input("input =>")
        while sentence != 'exit':
            category = self.classifier.predict_(sentence)
            print("category: {}".format(category))
            sentence = input("input =>")

    def save(self, filename):
        self.classifier.dump_json(filename)

    def load(self, filename):
        self.classifier.load_json(filename)


if __name__ == '__main__':
    classifier = Classifier()

    if len(sys.argv) >= 2:
        # load classifier settings and parameters
        classifier.load(sys.argv[1])
        classifier.predict_user_input()
        sys.exit(0)

    from config import Config
    config = Config('settings.cfg', 'TWITTER')

    classifier.learn_from_tweets(
        config.true_accounts,
        config.true_target_name
    )
    classifier.learn_from_tweets(
        config.false_accounts,
        config.false_target_name
    )
    # save the classifier parameters
    classifier.save('result.json')
    classifier.predict_user_input()
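Once training has produced result.json, the saved parameters can also be loaded and used directly. A minimal sketch (the sample sentence and the predicted label are illustrative, and splitter.py still needs a valid Yahoo appid and network access):

# Load previously saved parameters and classify one sentence.
from naivebayes import NaiveBayes

classifier = NaiveBayes()
classifier.load_json('result.json')
print(classifier.predict_('我が右手に封じられし漆黒の炎よ'))  # e.g. '中二病'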
config.py reads one section of settings.cfg and exposes its values as attributes:
import ast
import os
from configparser import ConfigParser


class Config(object):
    def __init__(self, filename, section):
        if not os.path.exists(filename):
            raise ValueError("{} does not exist".format(filename))

        parser = ConfigParser()
        parser.read(filename)
        config = dict(parser.items(section))

        # values are written as Python literals in the config file;
        # ast.literal_eval parses them safely (unlike eval, it cannot
        # execute arbitrary code)
        for key, item in config.items():
            config[key] = ast.literal_eval(item)

        # set the parsed values as attributes
        self.__dict__ = config
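With the [TWITTER] section of the settings.cfg template shown further below, every value is parsed as a Python literal, so the account lists come back as tuples. A usage sketch, assuming settings.cfg exists next to the script:

config = Config('settings.cfg', 'TWITTER')
print(config.true_target_name)  # a plain string
print(config.true_accounts)     # a tuple of account IDs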
Cross-validation script. It holds out part of the tweets as test data and reports precision, recall, F-measure, and accuracy:
# -*- coding: utf-8 -*-
import numpy as np

from naivebayes import NaiveBayes
from twitter_reader import get_tweets


class CrossValidation(object):
    def __init__(self):
        self.classifier = NaiveBayes()

    def create_data(self, user_ids):
        data = []
        for category, ids in user_ids.items():
            tweets = get_tweets(ids)
            categories = [category] * len(tweets)
            data += list(zip(tweets, categories))
        np.random.shuffle(data)
        return data

    def split(self, data, test_percentage):
        n_test = int(len(data) * test_percentage)
        n_training = len(data) - n_test
        # unzip (inverse of zip): each part becomes a
        # (tweets, categories) pair
        training = list(zip(*data[:n_training]))
        test = list(zip(*data[n_training:]))
        return training, test

    def show_tweets_with_labels(self, tweets, labels):
        for tweet, label in zip(tweets, labels):
            print("{}:\n{}\n".format(label, tweet))

    def evaluate(self, user_ids, test_percentage=0.2, verbose=True):
        """
        user_ids: Twitter IDs separated into categories.
        test_percentage: Ratio of the amount of test data extracted
        from tweets.
        """
        if not 0 <= test_percentage <= 1:
            raise ValueError("test_percentage must be between 0 and 1 "
                             "(inclusive).")

        data = self.create_data(user_ids)
        training, test = self.split(data, test_percentage)

        tweets, categories = training
        self.classifier.fit(tweets, categories)

        tweets, answers = test
        results = self.classifier.predict(tweets)
        if verbose:
            self.show_tweets_with_labels(tweets, results)
        return results, answers


class ClassificationResultEvaluator(object):
    def __init__(self, results, answers):
        self.labels = list(np.unique(answers))
        self.results = np.asarray(results)
        self.answers = np.asarray(answers)

    def count_n_true_positives(self, target_label):
        n_true_positives = 0
        for result, answer in zip(self.results, self.answers):
            if result == answer == target_label:
                n_true_positives += 1
        return n_true_positives

    def calc_accuracy(self):
        n_correct_answers = np.count_nonzero(self.results == self.answers)
        return n_correct_answers / len(self.answers)

    def calc_precision_recall_fmeasure(self, target_label):
        n_true_positives = self.count_n_true_positives(target_label)
        n_retrieved = np.count_nonzero(self.results == target_label)
        n_relevant = np.count_nonzero(self.answers == target_label)
        # guard against empty denominators so a degenerate split
        # does not raise ZeroDivisionError
        precision = n_true_positives / n_retrieved if n_retrieved else 0.0
        recall = n_true_positives / n_relevant if n_relevant else 0.0
        if precision + recall == 0:
            return precision, recall, 0.0
        fmeasure = 2 * precision * recall / (precision + recall)
        return precision, recall, fmeasure

    def report(self):
        max_label_length = max(map(len, self.labels))
        white_spaces = ' ' * max_label_length
        print("{} precision recall fmeasure".format(white_spaces))
        format_ = "{:" + str(max_label_length) + "} " \
                  "{:.3f} {:.3f} {:.3f}"
        for target_label in self.labels:
            t = self.calc_precision_recall_fmeasure(target_label)
            precision, recall, fmeasure = t
            print(format_.format(target_label, precision, recall, fmeasure))
        print("\naccuracy: {}\n".format(self.calc_accuracy()))


if __name__ == '__main__':
    from config import Config
    config = Config('settings.cfg', 'TWITTER')

    validator = CrossValidation()
    results, answers = validator.evaluate({
        config.true_target_name: config.true_accounts,
        config.false_target_name: config.false_accounts
    })
    evaluator = ClassificationResultEvaluator(results, answers)
    evaluator.report()
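The evaluator can be exercised on its own with hand-made labels, which makes the precision/recall arithmetic easy to check. In this toy run (labels invented for illustration, with ClassificationResultEvaluator imported), each class has 1 true positive out of 2 retrieved and 2 relevant, so precision, recall, and F-measure are all 0.5, and accuracy is also 0.5:

results = ['a', 'a', 'b', 'b']
answers = ['a', 'b', 'a', 'b']
evaluator = ClassificationResultEvaluator(results, answers)
evaluator.report()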
naivebayes.py, a Naive Bayes text classifier with additive smoothing:
# -*- coding: utf-8 -*-
import json
import math

import splitter


class NaiveBayes(object):
    def __init__(self):
        self.vocabulary = set()  # every distinct word seen so far
        self.word_count = {}  # {category1: {word1: 4, word2: 2, ...}, ...}
        self.alpha = 0.01  # additive smoothing parameter

    def count_word(self, word, category):
        """
        Count word occurrences per category.
        For example, if the word '内閣' (cabinet) appears in the
        politics category, increment the count of '内閣' for that
        category.
        """
        self.word_count.setdefault(category, {})
        self.word_count[category].setdefault(word, 0)
        self.word_count[category][word] += 1
        self.vocabulary.add(word)

    def fit(self, sentences, categories):
        assert len(sentences) == len(categories)
        for sentence, category in zip(sentences, categories):
            self.fit_(sentence, category)

    def fit_(self, sentence, category):
        words = splitter.split(sentence)
        for word in words:
            self.count_word(word, category)

    def calc_category_frequency(self, category):
        """
        Calculate what fraction of the whole training corpus belongs
        to documents of `category`.
        """
        # TODO the total does not need to be recomputed on every call
        # total number of words in the training data
        n_total_words = 0
        for c in self.word_count.keys():
            n_total_words += sum(self.word_count[c].values())
        # total number of words within the given category
        n_words_in_category = sum(self.word_count[category].values())
        return n_words_in_category / n_total_words

    def calc_word_frequency(self, word, category):
        """
        Calculate the probability P(`word`|`category`) that `word`
        appears within category `category`.
        """
        # number of occurrences of `word` across all documents of
        # `category`; 0 if `word` never appeared there
        word_occurrences = self.word_count[category].get(word, 0)
        # total number of words in `category`; repeated occurrences
        # of the same word all count
        n_words_in_category = sum(self.word_count[category].values())
        # number of distinct words
        V = len(self.vocabulary)
        # The plain estimate would be
        # ```probability = word_occurrences / n_words_in_category```
        # but it becomes 0 for any word unseen in `category`
        # (the zero-frequency problem), so additive smoothing with
        # parameter `alpha` is applied:
        probability = ((word_occurrences + self.alpha) /
                       (n_words_in_category + self.alpha * V))
        return probability

    def calc_score(self, words, category):
        # P(category|sentence) is proportional to
        #     P(category) * P(sentence|category)
        #   = P(category) * prod_i{P(word_i|category)}
        # so log(P(category|sentence)) is proportional to
        #     log(P(category)) + sum_i{log(P(word_i|category))}

        # log(P(category))
        score = math.log(self.calc_category_frequency(category))
        # sum_i{log(P(word_i|category))}
        for word in words:
            score += math.log(self.calc_word_frequency(word, category))
        return score

    def predict(self, sentences):
        categories = []
        for sentence in sentences:
            category = self.predict_(sentence)
            categories.append(category)
        return categories

    def predict_(self, sentence):
        # arg max over categories of log(P(category|sentence))
        best_suggested_category = None
        max_probability = -float('inf')
        words = splitter.split(sentence)
        for category in self.word_count.keys():
            probability = self.calc_score(words, category)
            if probability > max_probability:
                max_probability = probability
                best_suggested_category = category
        return best_suggested_category

    def dump_json(self, filename):
        # copy so that serialization does not mutate the live object
        attributes = dict(self.__dict__)
        # sets cannot be dumped as JSON, so convert to a list
        attributes['vocabulary'] = list(attributes['vocabulary'])
        with open(filename, 'w') as f:
            json.dump(attributes, f)

    def load_json(self, filename):
        with open(filename, 'r') as f:
            attributes = json.load(f)
        attributes['vocabulary'] = set(attributes['vocabulary'])
        self.__dict__ = attributes
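Because calc_score takes a word list directly, the probability model can be checked without calling the Yahoo splitter. The words and categories below are made up for illustration:

# Sketch of the scoring path, bypassing morphological analysis.
nb = NaiveBayes()
for word in ['内閣', '選挙', '内閣']:
    nb.count_word(word, '政治')
for word in ['野球', '選手']:
    nb.count_word(word, 'スポーツ')

# log P(category) + sum_i log P(word_i|category); higher is better
print(nb.calc_score(['内閣'], '政治'))
print(nb.calc_score(['内閣'], 'スポーツ'))  # much lower: '内閣' is unseen here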
settings.cfg, a template for the API credentials and the training accounts:
[YAHOO]
# obtain from http://developer.yahoo.co.jp/webapi/jlp/ma/v1/parse.html
appid = 'your yahoo app id'

[TWITTER]
# obtain from https://dev.twitter.com/
consumer_key = 'your consumer key'
consumer_secret = 'your consumer secret'
token = 'your access token'
token_secret = 'your access token secret'

# category labels ('中二病' = "chuunibyou", '中二病ではない' = "not chuunibyou")
# and the accounts whose tweets represent each category
true_target_name = '中二病'
true_accounts = 'true_account_id1','true_account_id2'
false_target_name = '中二病ではない'
false_accounts = 'false_account_id1','false_account_id2'
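Config feeds each value through ast.literal_eval, so every value has to be a valid Python literal; the comma-separated account lists rely on this to come back as tuples. A small illustration:

import ast
print(ast.literal_eval("'id1','id2'"))  # ('id1', 'id2')
print(ast.literal_eval("'中二病'"))      # 中二病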
splitter.py splits a Japanese sentence into morphemes using the Yahoo! morphological analysis API:
# -*- coding: utf-8 -*-
from urllib.parse import urlencode
from urllib.request import urlopen

from bs4 import BeautifulSoup

from config import Config

config = Config('settings.cfg', 'YAHOO')

pageurl = "http://jlp.yahooapis.jp/MAService/V1/parse"
results = "ma"  # request morphological analysis results
filter_ = "1|2|3|4|5|9|10"  # part-of-speech numbers to keep


def split(sentence):
    """Split a Japanese sentence into morphemes via the Yahoo! MA API"""
    params = urlencode({'appid': config.appid,
                        'results': results,
                        'filter': filter_,
                        'sentence': sentence})
    response = urlopen(pageurl, params.encode('utf-8'))
    soup = BeautifulSoup(response.read(), "lxml")
    return [w.surface.string for w in soup.ma_result.word_list]
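A call then returns the sentence's morphemes as a list of strings. A sketch, assuming a valid appid in settings.cfg and network access (the output shown is illustrative; the exact tokens depend on the POS filter):

import splitter
print(splitter.split('隣の客はよく柿食う客だ'))
# e.g. ['隣', '客', 'よく', '柿', '食う', '客']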
twitter_reader.py fetches recent tweets from the given accounts through the Twitter API:
import twitter

from config import Config

config = Config('settings.cfg', 'TWITTER')

auth = twitter.OAuth(
    consumer_key=config.consumer_key,
    consumer_secret=config.consumer_secret,
    token=config.token,
    token_secret=config.token_secret
)
api = twitter.Twitter(auth=auth)


def get_tweets(accounts):
    """Get the 200 most recent tweets from each account"""
    def get_tweets_(screen_name):
        statuses = api.statuses.user_timeline(screen_name=screen_name,
                                              count=200)
        return [s['text'] for s in statuses]

    tweets = []
    for account in accounts:
        tweets += get_tweets_(account)
    return tweets
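With valid credentials in settings.cfg, collecting training text is a single call. The screen names here are placeholders:

from twitter_reader import get_tweets

tweets = get_tweets(['some_account', 'another_account'])
print(len(tweets))  # up to 200 tweets per account
print(tweets[0])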