Created
March 17, 2016 03:56
-
-
Save naoyashiga/4dfaa7e2a5222a9cadd9 to your computer and use it in GitHub Desktop.
TextGeneratorをPython3系で動かす
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
u""" | |
マルコフ連鎖を用いて適当な文章を自動生成するファイル | |
""" | |
import os.path | |
import sqlite3 | |
import random | |
from PrepareChain import PrepareChain | |
class GenerateText(object): | |
u""" | |
文章生成用クラス | |
""" | |
def __init__(self, n=5): | |
u""" | |
初期化メソッド | |
@param n いくつの文章を生成するか | |
""" | |
self.n = n | |
def generate(self): | |
u""" | |
実際に生成する | |
@return 生成された文章 | |
""" | |
# DBが存在しないときは例外をあげる | |
if not os.path.exists(PrepareChain.DB_PATH): | |
raise IOError(u"DBファイルが存在しません") | |
# DBオープン | |
con = sqlite3.connect(PrepareChain.DB_PATH) | |
con.row_factory = sqlite3.Row | |
# 最終的にできる文章 | |
generated_text = u"" | |
# 指定の数だけ作成する | |
# for i in xrange(self.n): | |
for i in range(self.n): | |
text = self._generate_sentence(con) | |
generated_text += text | |
# DBクローズ | |
con.close() | |
return generated_text | |
def _generate_sentence(self, con): | |
u""" | |
ランダムに一文を生成する | |
@param con DBコネクション | |
@return 生成された1つの文章 | |
""" | |
# 生成文章のリスト | |
morphemes = [] | |
# はじまりを取得 | |
first_triplet = self._get_first_triplet(con) | |
morphemes.append(first_triplet[1]) | |
morphemes.append(first_triplet[2]) | |
# 文章を紡いでいく | |
while morphemes[-1] != PrepareChain.END: | |
prefix1 = morphemes[-2] | |
prefix2 = morphemes[-1] | |
triplet = self._get_triplet(con, prefix1, prefix2) | |
morphemes.append(triplet[2]) | |
# 連結 | |
result = "".join(morphemes[:-1]) | |
return result | |
def _get_chain_from_DB(self, con, prefixes): | |
u""" | |
チェーンの情報をDBから取得する | |
@param con DBコネクション | |
@param prefixes チェーンを取得するprefixの条件 tupleかlist | |
@return チェーンの情報の配列 | |
""" | |
# ベースとなるSQL | |
sql = u"select prefix1, prefix2, suffix, freq from chain_freqs where prefix1 = ?" | |
# prefixが2つなら条件に加える | |
if len(prefixes) == 2: | |
sql += u" and prefix2 = ?" | |
# 結果 | |
result = [] | |
# DBから取得 | |
cursor = con.execute(sql, prefixes) | |
for row in cursor: | |
result.append(dict(row)) | |
return result | |
def _get_first_triplet(self, con): | |
u""" | |
文章のはじまりの3つ組をランダムに取得する | |
@param con DBコネクション | |
@return 文章のはじまりの3つ組のタプル | |
""" | |
# BEGINをprefix1としてチェーンを取得 | |
prefixes = (PrepareChain.BEGIN,) | |
# チェーン情報を取得 | |
chains = self._get_chain_from_DB(con, prefixes) | |
# 取得したチェーンから、確率的に1つ選ぶ | |
triplet = self._get_probable_triplet(chains) | |
return (triplet["prefix1"], triplet["prefix2"], triplet["suffix"]) | |
def _get_triplet(self, con, prefix1, prefix2): | |
u""" | |
prefix1とprefix2からsuffixをランダムに取得する | |
@param con DBコネクション | |
@param prefix1 1つ目のprefix | |
@param prefix2 2つ目のprefix | |
@return 3つ組のタプル | |
""" | |
# BEGINをprefix1としてチェーンを取得 | |
prefixes = (prefix1, prefix2) | |
# チェーン情報を取得 | |
chains = self._get_chain_from_DB(con, prefixes) | |
# 取得したチェーンから、確率的に1つ選ぶ | |
triplet = self._get_probable_triplet(chains) | |
return (triplet["prefix1"], triplet["prefix2"], triplet["suffix"]) | |
def _get_probable_triplet(self, chains): | |
u""" | |
チェーンの配列の中から確率的に1つを返す | |
@param chains チェーンの配列 | |
@return 確率的に選んだ3つ組 | |
""" | |
# 確率配列 | |
probability = [] | |
# 確率に合うように、インデックスを入れる | |
for (index, chain) in enumerate(chains): | |
# for j in xrange(chain["freq"]): | |
for j in range(chain["freq"]): | |
probability.append(index) | |
# ランダムに1つを選ぶ | |
chain_index = random.choice(probability) | |
return chains[chain_index] | |
if __name__ == '__main__': | |
generator = GenerateText() | |
print(generator.generate()) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
u""" | |
与えられた文書からマルコフ連鎖のためのチェーン(連鎖)を作成して、DBに保存するファイル | |
""" | |
import unittest | |
import re | |
import MeCab | |
import sqlite3 | |
from collections import defaultdict | |
class PrepareChain(object): | |
u""" | |
チェーンを作成してDBに保存するクラス | |
""" | |
BEGIN = u"__BEGIN_SENTENCE__" | |
END = u"__END_SENTENCE__" | |
DB_PATH = "chain.db" | |
DB_SCHEMA_PATH = "schema.sql" | |
def __init__(self, text): | |
u""" | |
初期化メソッド | |
@param text チェーンを生成するための文章 | |
""" | |
# if isinstance(text, str): | |
# text = text.decode("utf-8") | |
self.text = text | |
# 形態素解析用タガー | |
self.tagger = MeCab.Tagger('-Ochasen') | |
def make_triplet_freqs(self): | |
u""" | |
形態素解析から3つ組の出現回数まで | |
@return 3つ組とその出現回数の辞書 key: 3つ組(タプル) val: 出現回数 | |
""" | |
# 長い文章をセンテンス毎に分割 | |
sentences = self._divide(self.text) | |
# 3つ組の出現回数 | |
triplet_freqs = defaultdict(int) | |
# センテンス毎に3つ組にする | |
for sentence in sentences: | |
# 形態素解析 | |
morphemes = self._morphological_analysis(sentence) | |
# 3つ組をつくる | |
triplets = self._make_triplet(morphemes) | |
# 出現回数を加算 | |
for (triplet, n) in triplets.items(): | |
triplet_freqs[triplet] += n | |
return triplet_freqs | |
def _divide(self, text): | |
u""" | |
「。」や改行などで区切られた長い文章を一文ずつに分ける | |
@param text 分割前の文章 | |
@return 一文ずつの配列 | |
""" | |
# 改行文字以外の分割文字(正規表現表記) | |
delimiter = u"。|.|\." | |
# 全ての分割文字を改行文字に置換(splitしたときに「。」などの情報を無くさないため) | |
text = re.sub(r"({0})".format(delimiter), r"\1\n", text) | |
# 改行文字で分割 | |
sentences = text.splitlines() | |
# 前後の空白文字を削除 | |
sentences = [sentence.strip() for sentence in sentences] | |
return sentences | |
def _morphological_analysis(self, sentence): | |
u""" | |
一文を形態素解析する | |
@param sentence 一文 | |
@return 形態素で分割された配列 | |
""" | |
morphemes = [] | |
# sentence = sentence.encode("utf-8") | |
node = self.tagger.parseToNode(sentence) | |
while node: | |
if node.posid != 0: | |
# morpheme = node.surface.decode("utf-8") | |
morpheme = node.surface | |
morphemes.append(morpheme) | |
node = node.next | |
return morphemes | |
def _make_triplet(self, morphemes): | |
u""" | |
形態素解析で分割された配列を、形態素毎に3つ組にしてその出現回数を数える | |
@param morphemes 形態素配列 | |
@return 3つ組とその出現回数の辞書 key: 3つ組(タプル) val: 出現回数 | |
""" | |
# 3つ組をつくれない場合は終える | |
if len(morphemes) < 3: | |
return {} | |
# 出現回数の辞書 | |
triplet_freqs = defaultdict(int) | |
# 繰り返し | |
# for i in xrange(len(morphemes)-2): | |
for i in range(len(morphemes)-2): | |
triplet = tuple(morphemes[i:i+3]) | |
triplet_freqs[triplet] += 1 | |
# beginを追加 | |
triplet = (PrepareChain.BEGIN, morphemes[0], morphemes[1]) | |
triplet_freqs[triplet] = 1 | |
# endを追加 | |
triplet = (morphemes[-2], morphemes[-1], PrepareChain.END) | |
triplet_freqs[triplet] = 1 | |
return triplet_freqs | |
def save(self, triplet_freqs, init=False): | |
u""" | |
3つ組毎に出現回数をDBに保存 | |
@param triplet_freqs 3つ組とその出現回数の辞書 key: 3つ組(タプル) val: 出現回数 | |
""" | |
# DBオープン | |
con = sqlite3.connect(PrepareChain.DB_PATH) | |
# 初期化から始める場合 | |
if init: | |
# DBの初期化 | |
with open(PrepareChain.DB_SCHEMA_PATH, "r") as f: | |
schema = f.read() | |
con.executescript(schema) | |
# データ整形 | |
datas = [(triplet[0], triplet[1], triplet[2], freq) for (triplet, freq) in triplet_freqs.items()] | |
# データ挿入 | |
p_statement = u"insert into chain_freqs (prefix1, prefix2, suffix, freq) values (?, ?, ?, ?)" | |
con.executemany(p_statement, datas) | |
# コミットしてクローズ | |
con.commit() | |
con.close() | |
def show(self, triplet_freqs): | |
u""" | |
3つ組毎の出現回数を出力する | |
@param triplet_freqs 3つ組とその出現回数の辞書 key: 3つ組(タプル) val: 出現回数 | |
""" | |
for triplet in triplet_freqs: | |
print("|".join(triplet), "\t", triplet_freqs[triplet]) | |
class TestFunctions(unittest.TestCase): | |
u""" | |
テスト用クラス | |
""" | |
def setUp(self): | |
u""" | |
テストが実行される前に実行される | |
""" | |
self.text = u"こんにちは。 今日は、楽しい運動会です。hello world.我輩は猫である\n 名前はまだない。我輩は犬である\r\n名前は決まってるよ" | |
self.chain = PrepareChain(self.text) | |
def test_make_triplet_freqs(self): | |
u""" | |
全体のテスト | |
""" | |
triplet_freqs = self.chain.make_triplet_freqs() | |
answer = {(u"__BEGIN_SENTENCE__", u"今日", u"は"): 1, (u"今日", u"は", u"、"): 1, (u"は", u"、", u"楽しい"): 1, (u"、", u"楽しい", u"運動会"): 1, (u"楽しい", u"運動会", u"です"): 1, (u"運動会", u"です", u"。"): 1, (u"です", u"。", u"__END_SENTENCE__"): 1, (u"__BEGIN_SENTENCE__", u"hello", u"world"): 1, (u"hello", u"world", u"."): 1, (u"world", u".", u"__END_SENTENCE__"): 1, (u"__BEGIN_SENTENCE__", u"我輩", u"は"): 2, (u"我輩", u"は", u"猫"): 1, (u"は", u"猫", u"で"): 1, (u"猫", u"で", u"ある"): 1, (u"で", u"ある", u"__END_SENTENCE__"): 2, (u"__BEGIN_SENTENCE__", u"名前", u"は"): 2, (u"名前", u"は", u"まだ"): 1, (u"は", u"まだ", u"ない"): 1, (u"まだ", u"ない", u"。"): 1, (u"ない", u"。", u"__END_SENTENCE__"): 1, (u"我輩", u"は", u"犬"): 1, (u"は", u"犬", u"で"): 1, (u"犬", u"で", u"ある"): 1, (u"名前", u"は", u"決まっ"): 1, (u"は", u"決まっ", u"てる"): 1, (u"決まっ", u"てる", u"よ"): 1, (u"てる", u"よ", u"__END_SENTENCE__"): 1} | |
self.assertEqual(triplet_freqs, answer) | |
def test_divide(self): | |
u""" | |
一文ずつに分割するテスト | |
""" | |
sentences = self.chain._divide(self.text) | |
answer = [u"こんにちは。", u"今日は、楽しい運動会です。", u"hello world.", u"我輩は猫である", u"名前はまだない。", u"我輩は犬である", u"名前は決まってるよ"] | |
self.assertEqual(sentences.sort(), answer.sort()) | |
def test_morphological_analysis(self): | |
u""" | |
形態素解析用のテスト | |
""" | |
sentence = u"今日は、楽しい運動会です。" | |
morphemes = self.chain._morphological_analysis(sentence) | |
answer = [u"今日", u"は", u"、", u"楽しい", u"運動会", u"です", u"。"] | |
self.assertEqual(morphemes.sort(), answer.sort()) | |
def test_make_triplet(self): | |
u""" | |
形態素毎に3つ組にしてその出現回数を数えるテスト | |
""" | |
morphemes = [u"今日", u"は", u"、", u"楽しい", u"運動会", u"です", u"。"] | |
triplet_freqs = self.chain._make_triplet(morphemes) | |
answer = {(u"__BEGIN_SENTENCE__", u"今日", u"は"): 1, (u"今日", u"は", u"、"): 1, (u"は", u"、", u"楽しい"): 1, (u"、", u"楽しい", u"運動会"): 1, (u"楽しい", u"運動会", u"です"): 1, (u"運動会", u"です", u"。"): 1, (u"です", u"。", u"__END_SENTENCE__"): 1} | |
self.assertEqual(triplet_freqs, answer) | |
def test_make_triplet_too_short(self): | |
u""" | |
形態素毎に3つ組にしてその出現回数を数えるテスト | |
ただし、形態素が少なすぎる場合 | |
""" | |
morphemes = [u"こんにちは", u"。"] | |
triplet_freqs = self.chain._make_triplet(morphemes) | |
answer = {} | |
self.assertEqual(triplet_freqs, answer) | |
def test_make_triplet_3morphemes(self): | |
u""" | |
形態素毎に3つ組にしてその出現回数を数えるテスト | |
ただし、形態素がちょうど3つの場合 | |
""" | |
morphemes = [u"hello", u"world", u"."] | |
triplet_freqs = self.chain._make_triplet(morphemes) | |
answer = {(u"__BEGIN_SENTENCE__", u"hello", u"world"): 1, (u"hello", u"world", u"."): 1, (u"world", u".", u"__END_SENTENCE__"): 1} | |
self.assertEqual(triplet_freqs, answer) | |
def tearDown(self): | |
u""" | |
テストが実行された後に実行される | |
""" | |
pass | |
if __name__ == '__main__': | |
# unittest.main() | |
# テキストをファイルから読み込む | |
ero_file = open('../ero.txt', encoding='utf-8') | |
text = ero_file.read() | |
chain = PrepareChain(text) | |
triplet_freqs = chain.make_triplet_freqs() | |
chain.save(triplet_freqs, True) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment