Skip to content

Instantly share code, notes, and snippets.

@kzinmr
Last active January 25, 2023 13:07
Show Gist options
  • Save kzinmr/cf62e4411dc99df6128d9bf3a8688dd6 to your computer and use it in GitHub Desktop.
Save kzinmr/cf62e4411dc99df6128d9bf3a8688dd6 to your computer and use it in GitHub Desktop.
import multiprocessing
import kenlm
import MeCab
# Module-level language model shared with the multiprocessing workers.
# Assigned by Punctuator.__init__ and read by _process and
# Punctuator.__punct_insert_position; it must expose ``score(text)``.
# NOTE(review): workers only see the assigned model if they inherit the
# parent's memory (the "fork" start method); under "spawn"/"forkserver" the
# module is re-imported in each worker and this would still be None — confirm
# the deployment platform.
lm_model = None
def to_punct_between_texts(
    text_l, text_r, l_model, tokenizer, threshold=0.0, punct="。", print_score=False,
):
    """Decide whether a punctuation mark should be placed between two strings.

    Both strings are tokenized with ``tokenizer.wakati`` and scored by
    ``l_model`` twice: once as a plain concatenation and once with ``punct``
    inserted between them.

    Args:
        text_l: Text on the left of the candidate position.
        text_r: Text on the right of the candidate position.
        l_model: Language model exposing ``score(text) -> float``.
        tokenizer: Object whose ``wakati(text)`` returns space-separated tokens.
        threshold: Minimum relative gain required to insert ``punct``.
        punct: Punctuation mark to test.
        print_score: If True, print the raw and relative scores.

    Returns:
        True when ``(score_punct - score) / abs(score)`` exceeds ``threshold``.
    """
    # split() rather than split(" "): it drops empty tokens and any trailing
    # whitespace/newline the tokenizer emits (MeCab's -Owakati output typically
    # ends with " \n"), so those are not scored as words.
    words_l = tokenizer.wakati(text_l).split()
    words_r = tokenizer.wakati(text_r).split()
    score = l_model.score(" ".join(words_l + words_r))
    score_p = l_model.score(" ".join(words_l + [punct] + words_r))
    gain = score_p - score
    # Normalize by the unpunctuated score so the threshold is length-independent.
    gain_r = gain / abs(score)
    if print_score:
        print(
            f"Score:{score:.03f}, Score_punct:{score_p:.03f}, Gain_After_punct:{gain:.03f}, Gain_Ratio:{gain_r:.03f}"
        )
    return gain_r > threshold
def _process(args):
    """Pool worker: evaluate a single punctuation insertion point.

    ``args`` is the tuple ``(i, words, window, punct, score, threshold)``.
    The mark is placed after the first ``i + window`` tokens, and the result
    is scored with the module-level ``lm_model``.

    Returns the insertion index ``i + window`` when the relative score gain
    exceeds ``threshold``, otherwise ``None``.
    """
    # Deliberately not shared with to_punct_between_texts: pool workers have to
    # read the global lm_model, and controlling that global from the shared
    # helper was judged too cumbersome.
    offset, tokens, window, mark, base_score, threshold = args
    cut = offset + window
    candidate = " ".join([*tokens[:cut], mark, *tokens[cut:]])
    relative_gain = (lm_model.score(candidate) - base_score) / abs(base_score)
    if relative_gain > threshold:
        return cut
    return None
class Punctuator:
    """Insert punctuation into a string that has none, guided by a language model.

    A mark is tried after each token from position ``window`` onward. It is
    kept when inserting it raises the language-model score by more than
    ``threshold``, relative to the unpunctuated score.

    NOTE: the model is stored in the module-level ``lm_model`` global so the
    pool workers (``_process``) can reach it. Constructing another Punctuator
    therefore replaces the model used by every existing instance.
    """

    def __init__(self, model, window, tokenizer, puncts=None, threshold=0.0):
        """
        Args:
            model: Language model exposing ``score(text) -> float``.
            window: Number of leading tokens before the first candidate
                position (``main`` passes the model's n-gram order).
            tokenizer: Object whose ``wakati(text)`` returns a
                space-separated token string.
            puncts: Punctuation marks; defaults to ``["。"]``.
                NOTE(review): stored but not used by ``punct_insert``, which
                always inserts "。".
            threshold: Minimum relative score gain needed to insert a mark.
        """
        global lm_model
        lm_model = model
        self.window = window
        self.puncts = ["。"] if puncts is None else puncts
        self.tokenizer = tokenizer
        self.threshold = threshold

    def wakati(self, t):
        """Tokenize ``t`` with the configured tokenizer."""
        return self.tokenizer.wakati(t)

    def __punct_insert_position(self, sentence, punct="。"):
        """Return the positions where inserting ``punct`` raises the LM score.

        ``sentence`` is a space-separated token string. A returned position
        ``k`` means "insert after the first k tokens"; candidates range over
        ``window .. len(tokens)``.
        """
        words = sentence.split(" ")
        starts = range(len(words) - self.window + 1)
        if not starts:
            # Fewer tokens than the window: no candidates, so skip the
            # cost of starting a process pool.
            return []
        score = lm_model.score(sentence)
        tasks = [
            (i, words, self.window, punct, score, self.threshold) for i in starts
        ]
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            positions = pool.map(_process, tasks)
        return [pos for pos in positions if pos is not None]

    @staticmethod
    def __punct_decode(sentence, punct_pos, punct="。"):
        """Insert ``punct`` into a tokenized sentence at the given positions.

        A position ``k`` places ``punct`` after the k-th token (1-based).
        Returns the result as a space-separated token string.
        """
        # Set for O(1) membership; the list lookup made this O(n * m).
        after = set(punct_pos)
        punct_sent = []
        for idx, word in enumerate(sentence.split(" "), start=1):
            punct_sent.append(word)
            if idx in after:
                punct_sent.append(punct)
        return " ".join(punct_sent)

    def punct_insert(self, sentence, wakati=False):
        """Tokenize ``sentence`` and insert "。" wherever the LM score improves.

        Args:
            sentence: Raw text without punctuation.
            wakati: If True, return the space-separated token string;
                otherwise return the text with spaces removed.
        """
        sentence_w = self.wakati(sentence)
        positions = self.__punct_insert_position(sentence_w, punct="。")
        punctuated = self.__punct_decode(sentence_w, positions, punct="。")
        return punctuated if wakati else punctuated.replace(" ", "")
class MeCabWrapper:
    """Thin wrapper around a MeCab tagger configured with ``-Owakati``."""

    def __init__(self):
        # Must not be named ``wakati``: an instance attribute of that name
        # would shadow the ``wakati`` method, so calling it would try to call
        # the Tagger object itself.
        self.tagger = MeCab.Tagger("-Owakati")

    def wakati(self, text: str) -> str:
        """Return ``text`` segmented into space-separated tokens.

        A string (not a list) is returned because the callers in this module
        split the result themselves. Surrounding whitespace, including the
        trailing newline of the tagger output, is stripped.
        """
        return self.tagger.parse(text).strip()
def main():
    """Demo: punctuate a sample sentence with a KenLM model and MeCab.

    Expects the model file 'model.100K.klm' in the working directory.
    """
    lm_path = 'model.100K.klm'
    tokenizer = MeCabWrapper()
    m = kenlm.Model(lm_path)
    pt = Punctuator(m, m.order, tokenizer)
    # Should a punctuation mark be placed between the two strings?
    test_l, test_r = '今日もいい天気ですね。明日の天気も晴れとのこと', '控えめに言って最高'
    # Local names, not ``self.``: main is a plain function, so ``self`` is undefined.
    assert to_punct_between_texts(test_l, test_r, m, tokenizer, print_score=True)
    # Punctuate a string whose punctuation has been removed.
    test_t = (test_l + test_r).replace('。', '')
    assert '今日もいい天気ですね。明日の天気も晴れとのこと。控えめに言って最高。' == pt.punct_insert(test_t)


# Guard required so multiprocessing workers that re-import this module
# (non-fork start methods) do not re-run the demo.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment