Last active
January 25, 2023 13:07
-
-
Save kzinmr/cf62e4411dc99df6128d9bf3a8688dd6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing | |
import kenlm | |
import MeCab | |
# Module-level handle to the kenlm language model. Punctuator.__init__ sets it
# so that `_process`, run inside multiprocessing.Pool workers, can reach the
# model. NOTE(review): this presumably relies on the "fork" start method
# (workers inherit the global); on spawn-based platforms it would be None in
# the workers — confirm target platform.
lm_model = None
def to_punct_between_texts(
    text_l, text_r, l_model, tokenizer, threshold=0.0, punct="。", print_score=False,
):
    """Decide whether a punctuation mark should be placed between two strings.

    Tokenizes both strings with `tokenizer.wakati`, scores the concatenation
    with and without `punct` inserted at the boundary using `l_model.score`,
    and returns True when the relative score gain from inserting `punct`
    exceeds `threshold`.

    Args:
        text_l: left-hand text.
        text_r: right-hand text.
        l_model: language model exposing `score(str) -> float`.
        tokenizer: tokenizer exposing `wakati(str) -> str` (space-separated).
        threshold: minimum relative gain required to answer True.
        punct: punctuation token to try inserting.
        print_score: when True, print the raw and relative scores.
    """
    words_l = tokenizer.wakati(text_l).split(" ")
    words_r = tokenizer.wakati(text_r).split(" ")
    text = " ".join(words_l + words_r)
    score = l_model.score(text)
    text_punct = " ".join(words_l + [punct] + words_r)
    score_p = l_model.score(text_punct)
    gain = score_p - score
    # Normalize by the magnitude of the no-punct score so the threshold is
    # comparable across texts of different lengths.
    gain_r = gain / abs(score)
    if print_score:
        print(
            f"Score:{score:.03f}, Score_punct:{score_p:.03f}, Gain_After_punct:{gain:.03f}, Gain_Ratio:{gain_r:.03f}"
        )
    # Idiom fix: return the comparison directly instead of
    # `True if gain_r > threshold else False`.
    return gain_r > threshold
def _process(args):
    """Pool worker: test one punctuation insertion point for Punctuator.

    `args` is a tuple of (position, tokenized words, model n-gram window,
    punct token, base LM score of the unpunctuated sentence, gain threshold).
    Returns the token index after which to punctuate, or None if the relative
    score gain does not exceed the threshold.
    """
    position, words, window, punct, base_score, threshold = args
    # Intentionally not shared with to_punct_between_texts: controlling the
    # behavior of the `global lm_model` across worker processes is cumbersome.
    split_at = position + window
    candidate = " ".join(words[:split_at] + [punct] + words[split_at:])
    relative_gain = (lm_model.score(candidate) - base_score) / abs(base_score)
    if relative_gain > threshold:
        return split_at
    return None
class Punctuator:
    """Insert punctuation into an unpunctuated string via LM score gains."""

    def __init__(self, model, window, tokenizer, puncts=None, threshold=0.0):
        # The model is published through the module-level `lm_model` global so
        # that multiprocessing workers running `_process` can access it.
        global lm_model
        lm_model = model
        self.window = window
        # Default to the Japanese full stop when no punctuation set is given.
        self.puncts = ["。"] if puncts is None else puncts
        self.tokenizer = tokenizer
        self.threshold = threshold

    def wakati(self, t):
        """Tokenize `t` with the wrapped tokenizer (space-separated string)."""
        return self.tokenizer.wakati(t)

    def __punct_insert_position(self, sentence, punct="。"):
        """Return token indices where inserting `punct` raises the LM score.

        `sentence` is a space-tokenized string; each candidate position is
        scored in parallel by `_process` worker processes.
        """
        window = self.window
        words = sentence.split(" ")
        score = lm_model.score(" ".join(words))
        tasks = [
            (i, words, window, punct, score, self.threshold)
            for i in range(len(words) - window + 1)
        ]
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            # Fix: map directly into a flat list instead of appending the
            # result list into another list and indexing `[0]` afterwards.
            positions = pool.map(_process, tasks)
        return [pos for pos in positions if pos is not None]

    @staticmethod
    def __punct_decode(sentence, punct_pos, punct="。"):
        """Re-emit a tokenized sentence with `punct` at the given positions."""
        words = sentence.split(" ")
        punct_sent = []
        for i in range(len(words)):
            punct_sent.append(words[i])
            if i + 1 in punct_pos:
                punct_sent.append(punct)
        return " ".join(punct_sent)

    def punct_insert(self, sentence, wakati=False, punct="。"):
        """Tokenize `sentence` and insert `punct` wherever the LM score gains.

        Args:
            sentence: raw (unpunctuated) input text.
            wakati: when True, return the space-tokenized result; otherwise
                the tokens are joined back into a plain string.
            punct: punctuation token to insert (generalized from the
                previously hard-coded "。"; default keeps old behavior).
        """
        sentence_w = self.wakati(sentence)
        punct_position = self.__punct_insert_position(sentence_w, punct=punct)
        sentence_pred_w = self.__punct_decode(sentence_w, punct_position, punct=punct)
        if wakati:
            return sentence_pred_w
        return "".join(sentence_pred_w.split(" "))
class MeCabWrapper:
    """Thin wrapper exposing MeCab word segmentation as `wakati(text) -> str`."""

    def __init__(self):
        # BUG FIX: the original assigned the Tagger to `self.wakati`, which
        # shadowed the `wakati` method on the instance — calling
        # `instance.wakati(text)` raised TypeError (Tagger is not callable).
        self._tagger = MeCab.Tagger("-Owakati")

    def wakati(self, text: str) -> str:
        # Return a space-separated token string (stripped of the trailing
        # newline MeCab emits). Callers (`Punctuator.wakati`,
        # `to_punct_between_texts`) expect a string they can `.split(" ")`,
        # not the list the original returned.
        return self._tagger.parse(text).strip()
def main():
    """Smoke-test both punctuation modules against a pre-trained kenlm model."""
    LM = 'model.100K.klm'
    tokenizer = MeCabWrapper()
    m = kenlm.Model(LM)
    pt = Punctuator(m, m.order, tokenizer)
    # Decide whether punctuation belongs between two sentences.
    test_l, test_r = '今日もいい天気ですね。明日の天気も晴れとのこと', '控えめに言って最高'
    # BUG FIX: the original referenced `self.m` / `self.pt` inside a plain
    # function, which raises NameError; use the local variables instead.
    assert to_punct_between_texts(test_l, test_r, m, tokenizer, print_score=True)
    # Insert punctuation into an unpunctuated string.
    test_t = (test_l + test_r).replace('。', '')
    assert '今日もいい天気ですね。明日の天気も晴れとのこと。控えめに言って最高。' == pt.punct_insert(test_t)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://drive.google.com/file/d/1zaQWbqHOABlSspJmGLrcnc59MlYQawP4/view?usp=share_link