Created
May 14, 2020 04:35
-
-
Save kzinmr/ccfb678a545d254bf6ef8cf4c8dddb4b to your computer and use it in GitHub Desktop.
A simple bidirectional language model with nltk
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Dict, Set, Optional | |
from nltk.lm import MLE | |
from nltk.util import ngrams | |
class InvalidOrderException(Exception):
    """Raised when a model is constructed with an unsupported n-gram order."""
class InvalidContextSizeException(Exception):
    """Raised when a context has the wrong number of tokens for the model order."""
class NotInVocabException(Exception):
    """Raised when a queried token is not in the fitted vocabulary."""
class BidirectionalLanguageModel:
    """A simple bidirectional n-gram language model built on nltk's MLE.

    Two MLE models of the same order are trained: one on the original
    n-grams (forward) and one on n-grams reversed both within each n-gram
    and across the document (backward), so a target token can be scored
    given either its left or its right context.
    """

    def __init__(
        self,
        order: int,
    ):
        """
        Args:
            order: n-gram order; must be >= 2.

        Raises:
            InvalidOrderException: if order <= 1 (unigram models are not
                supported, since this interface requires a context).
        """
        if order <= 1:
            raise InvalidOrderException
        self.order = order
        self.lm = MLE(order)
        self.lm_rev = MLE(order)

    @staticmethod
    def _invert_counts(context2targets, vocabulary: Set[str]) -> Dict[str, List]:
        """Build a target -> [contexts] index from a conditional count table."""
        target2contexts: Dict[str, List] = {target: [] for target in vocabulary}
        for context, targets_distr in context2targets.items():
            for target in targets_distr:
                target2contexts[target].append(context)
        return target2contexts

    def fit(
        self,
        documents: List[List[str]],
        vocabulary: Optional[Set[str]] = None
    ):
        """Fit the forward and backward models on tokenized documents.

        Args:
            documents: list of tokenized documents (lists of token strings).
            vocabulary: optional token vocabulary; defaults to every token
                observed in `documents`.

        Side effects: stores the training n-grams, the vocabulary, and
        inverted context/target indexes for model inspection.
        """
        if vocabulary is None:
            vocabulary = {token for doc in documents for token in doc}
        # FIXME: consider sentence padding, i.e. pad_left=True,
        # pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'
        train = [list(ngrams(doc, self.order)) for doc in documents]
        # Backward corpus: reverse tokens inside each n-gram AND the order
        # of the n-grams within each document.
        train_rev = [[ngram[::-1] for ngram in ngrams_doc][::-1]
                     for ngrams_doc in train]
        self.lm.fit(train, vocabulary)
        self.lm_rev.fit(train_rev, vocabulary)
        self.train_ = train
        self.train_rev_ = train_rev
        self.vocabulary = vocabulary
        # Inverted indexes for model inspection: context -> target counts,
        # and target -> list of contexts it was observed in.
        self.context2targets = self.lm.counts[self.order]
        self.target2contexts = self._invert_counts(self.context2targets, vocabulary)
        self.context2targets_rev = self.lm_rev.counts[self.order]
        self.target2contexts_rev = self._invert_counts(self.context2targets_rev, vocabulary)

    def score(
        self,
        target: str,
        context: List[str]
    ) -> Dict[str, float]:
        """Score `target` given `context` in both directions.

        Args:
            target: token to score.
            context: exactly `order - 1` context tokens.

        Returns:
            {'forward': P(target | left context), 'backward': reverse-model score}.

        Raises:
            InvalidContextSizeException: if len(context) != order - 1.
        """
        if len(context) != self.order - 1:
            raise InvalidContextSizeException
        return {
            'forward': self.lm.score(target, context),
            'backward': self.lm_rev.score(target, context)
        }

    def get_contexts(
        self,
        target: str
    ) -> Dict[str, List]:
        """Return the contexts in which `target` was observed, per direction.

        Raises:
            NotInVocabException: if `target` is not in the fitted vocabulary.
        """
        if target not in self.vocabulary:
            raise NotInVocabException
        # The indexes are pre-populated for the whole vocabulary, so the
        # fallback is defensive only; return [] to match the List annotation.
        return {
            'forward': self.target2contexts.get(target, []),
            'backward': self.target2contexts_rev.get(target, [])
        }

    def get_targets(
        self,
        context: List[str]
    ) -> Dict[str, List]:
        """Return the observed target frequency tables for `context`.

        The count tables are keyed by (order-1)-tuples, so the context is
        converted to a tuple before lookup.  (Previously this method
        required `order` tokens and used the unhashable list itself as the
        key, so the lookup could never succeed.)

        Raises:
            InvalidContextSizeException: if len(context) != order - 1.
        """
        if len(context) != self.order - 1:
            raise InvalidContextSizeException
        key = tuple(context)
        # Use membership tests rather than indexing: the underlying
        # ConditionalFreqDist is a defaultdict, and indexing a missing key
        # would silently insert an empty entry.
        return {
            'forward': self.context2targets[key] if key in self.context2targets else {},
            'backward': self.context2targets_rev[key] if key in self.context2targets_rev else {}
        }

    def generate(
        self,
        context: List[str]
    ):
        """Generate one token after `context` from each directional model."""
        return {
            'forward': self.lm.generate(text_seed=context),
            'backward': self.lm_rev.generate(text_seed=context)
        }

    def save(self, path):
        # TODO: persistence is not implemented yet.
        pass

    def load(self, path):
        # TODO: persistence is not implemented yet.
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment