@kzinmr
Created May 14, 2020 04:35
A simple bidirectional language model with nltk
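# Overview: the class below trains two nltk MLE n-gram models, a forward model
# on the documents' n-grams and a backward model on the same n-grams with each
# tuple and the tuple sequence reversed. Toy illustration (example values only,
# not from the gist) of the backward training data for order = 3:
#   doc = ['a', 'b', 'c', 'd']
#   forward  n-grams: [('a', 'b', 'c'), ('b', 'c', 'd')]
#   backward n-grams: [('d', 'c', 'b'), ('c', 'b', 'a')]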
from typing import List, Dict, Set, Optional
from nltk.lm import MLE
from nltk.util import ngrams


class InvalidOrderException(Exception):
    pass


class InvalidContextSizeException(Exception):
    pass


class NotInVocabException(Exception):
    pass


class BidirectionalLanguageModel:
    def __init__(
        self,
        order: int,
    ):
        if order > 1:
            self.order = order
            self.lm = MLE(order)
            self.lm_rev = MLE(order)
        else:  # unigrams are not supported, since the interface assumes a non-empty context
            raise InvalidOrderException

    def fit(
        self,
        documents: List[List[str]],
        vocabulary: Optional[Set[str]] = None
    ):
        if vocabulary is None:
            vocabulary = {token for doc in documents for token in doc}
        # FIXME: pad_left=True, pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'
        train = [list(ngrams(doc, self.order)) for doc in documents]
        # Backward training data: reverse each n-gram and the n-gram sequence, so the
        # backward "context" is the tokens following the target, furthest token first.
        train_rev = [[ngram[::-1] for ngram in ngrams_doc][::-1] for ngrams_doc in train]
        self.lm.fit(train, vocabulary)
        self.lm_rev.fit(train_rev, vocabulary)
        self.train_ = train
        self.train_rev_ = train_rev
        self.vocabulary = vocabulary
        # Index the highest-order counts both ways (context -> targets, target -> contexts)
        # for model inspection.
        self.context2targets = self.lm.counts[self.order]
        self.target2contexts = {target: [] for target in vocabulary}
        for context, targets_distr in self.context2targets.items():
            for target, freq in targets_distr.items():
                self.target2contexts[target].append(context)
        self.context2targets_rev = self.lm_rev.counts[self.order]
        self.target2contexts_rev = {target: [] for target in vocabulary}
        for context, targets_distr in self.context2targets_rev.items():
            for target, freq in targets_distr.items():
                self.target2contexts_rev[target].append(context)

    def score(
        self,
        target: str,
        context: List[str]
    ) -> Dict[str, float]:
        # `context` must have length order - 1. For the forward model it is the
        # preceding tokens; for the backward model it is read as the following
        # tokens in reverse order (see fit()).
        if len(context) == self.order - 1:
            return {
                'forward': self.lm.score(target, context),
                'backward': self.lm_rev.score(target, context)
            }
        else:
            raise InvalidContextSizeException

    def get_contexts(
        self,
        target: str
    ) -> Dict[str, List]:
        if target not in self.vocabulary:
            raise NotInVocabException
        # Observed contexts in which `target` appears, per direction.
        contexts = self.target2contexts.get(target, [])
        contexts_rev = self.target2contexts_rev.get(target, [])
        return {
            'forward': contexts,
            'backward': contexts_rev
        }

    def get_targets(
        self,
        context: List[str]
    ) -> Dict[str, List]:
        # Contexts are stored as (order - 1)-tuples, matching score().
        if len(context) != self.order - 1:
            raise InvalidContextSizeException
        context = tuple(context)
        targets = self.context2targets[context] if context in self.context2targets else {}
        targets_rev = self.context2targets_rev[context] if context in self.context2targets_rev else {}
        return {
            'forward': targets,
            'backward': targets_rev
        }

    def generate(
        self,
        context: List[str]
    ):
        # Sample one token per direction, conditioned on `context` as the seed text.
        return {
            'forward': self.lm.generate(text_seed=context),
            'backward': self.lm_rev.generate(text_seed=context)
        }

    def save(self, path):
        pass

    def load(self, path):
        pass
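

if __name__ == '__main__':
    # Minimal usage sketch (toy data assumed for illustration; not part of the
    # original gist). Fits a trigram model in both directions and queries it.
    docs = [
        ['the', 'cat', 'sat', 'on', 'the', 'mat'],
        ['the', 'dog', 'sat', 'on', 'the', 'rug'],
    ]
    blm = BidirectionalLanguageModel(order=3)
    blm.fit(docs)

    # Forward probability P(sat | 'the cat'); the 'backward' entry of the same
    # call treats ('the', 'cat') as following tokens, which never occurs, so 0.
    print(blm.score('sat', ['the', 'cat']))

    # Backward probability of 'sat' being followed by 'on the', passed furthest token first.
    print(blm.score('sat', ['the', 'on']))

    # Observed contexts for 'sat' and observed targets for ('the', 'cat').
    print(blm.get_contexts('sat'))
    print(blm.get_targets(['the', 'cat']))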