SentencePiece AllenNLP

import os
import tempfile
import logging
from typing import List

import sentencepiece as spm
from overrides import overrides

from allennlp.common import Params
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data.tokenizers.word_filter import WordFilter, PassThroughWordFilter
from allennlp.data.tokenizers.word_splitter import WordSplitter, JustSpacesWordSplitter
from allennlp.data.tokenizers.word_stemmer import WordStemmer, PassThroughWordStemmer

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@Tokenizer.register("sentencepiece")
class SentencePieceTokenizer(Tokenizer):
""" | |
A ``WordTokenizer`` handles the splitting of strings into words as well as any desired | |
post-processing (e.g., stemming, filtering, etc.). Note that we leave one particular piece of | |
post-processing for later: the decision of whether or not to lowercase the token. This is for | |
two reasons: (1) if you want to make two different casing decisions for whatever reason, you | |
won't have to run the tokenizer twice, and more importantly (2) if you want to lowercase words | |
for your word embedding, but retain capitalization in a character-level representation, we need | |
to retain the capitalization here. | |
Parameters | |
---------- | |
word_splitter : ``WordSplitter``, optional | |
The :class:`WordSplitter` to use for splitting text strings into word tokens. The default | |
is to use the ``JustSpacesWordSplitter``, which is non destructive other than for spaces. | |
word_filter : ``WordFilter``, optional | |
The :class:`WordFilter` to use for, e.g., removing stopwords. Default is to do no | |
filtering. | |
word_stemmer : ``WordStemmer``, optional | |
The :class:`WordStemmer` to use. Default is no stemming. | |
start_tokens : ``List[str]``, optional | |
If given, these tokens will be added to the beginning of every string we tokenize. | |
end_tokens : ``List[str]``, optional | |
If given, these tokens will be added to the end of every string we tokenize. | |
""" | |

    def __init__(self,
                 model_path: str,
                 word_splitter: WordSplitter = None,
                 word_filter: WordFilter = PassThroughWordFilter(),
                 word_stemmer: WordStemmer = PassThroughWordStemmer(),
                 start_tokens: List[str] = None,
                 end_tokens: List[str] = None,
                 vocab_size: int = 30000,
                 model_type: str = 'unigram') -> None:
        self._word_splitter = word_splitter or JustSpacesWordSplitter()
        self._word_filter = word_filter
        self._word_stemmer = word_stemmer
        self._start_tokens = start_tokens or []
        # We reverse the tokens here because we're going to insert them with `insert(0)` later;
        # this makes sure they show up in the right order.
        self._start_tokens.reverse()
        self._end_tokens = end_tokens or []
        self._vocab_size = vocab_size
        self._model_path = model_path
        self._model = spm.SentencePieceProcessor()
        self._model_type = model_type
        self.trained = False
        if os.path.exists(self._model_path):
            self._model.Load(self._model_path)
            self.trained = True

    def train(self, texts: List[str]) -> None:
        """
        Train the subword model on the given texts and save it to ``model_path``.
        """
        if self.trained:
            logger.warning("Tokenizer model already exists, skipping training.")
            return
        preprocessed_texts = self._batch_preprocess(texts)
        # SentencePiece expects one sentence per line in a plain-text file, so we write the
        # preprocessed sentences to a temporary file and close it before training.
        with tempfile.NamedTemporaryFile('w', delete=False) as temp_fp:
            temp_fp.write("\n".join(preprocessed_texts) + "\n")
        model_prefix = 'subword'
        spm.SentencePieceTrainer.Train(f' --hard_vocab_limit=false'
                                       f' --input={temp_fp.name}'
                                       f' --model_prefix={model_prefix}'
                                       f' --vocab_size={self._vocab_size}'
                                       f' --model_type={self._model_type}')
        os.remove(temp_fp.name)
        os.rename(f'{model_prefix}.model', self._model_path)
        os.remove(f'{model_prefix}.vocab')  # Discard the vocab file; we build an AllenNLP Vocabulary instead.
        self._model.Load(self._model_path)
        self.trained = True
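
    # For reference, the ``SentencePieceTrainer.Train`` call above is roughly equivalent to
    # running the ``spm_train`` command-line tool with the same flags (paths illustrative):
    #   spm_train --hard_vocab_limit=false --input=/tmp/sentences.txt \
    #             --model_prefix=subword --vocab_size=30000 --model_type=unigram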

    @overrides
    def tokenize(self, text: str) -> List[Token]:
        """
        Does whatever processing is required to convert a string of text into a sequence of
        tokens. At a minimum, this uses the ``WordSplitter`` to split the text into words,
        optionally filters and stems them, and then encodes the result into subword pieces.
        """
        words = self._word_splitter.split_words(text)
        words = self._filter_and_stem(words)
        preprocessed_text = " ".join(map(str, words))
        return self._tokenize(preprocessed_text)

    @overrides
    def batch_tokenize(self, texts: List[str]) -> List[List[Token]]:
        preprocessed_texts = self._batch_preprocess(texts)
        return [self._tokenize(text) for text in preprocessed_texts]

    def _batch_preprocess(self, texts: List[str]) -> List[str]:
        """
        Applies word splitting, filtering and stemming to a batch of sentences and returns the
        re-joined sentences.
        """
        batched_words = self._word_splitter.batch_split_words(texts)
        batched_words = [self._filter_and_stem(words) for words in batched_words]
        return [" ".join(map(str, words)) for words in batched_words]

    def _filter_and_stem(self, words: List[Token]) -> List[Token]:
        filtered_words = self._word_filter.filter_words(words)
        stemmed_words = [self._word_stemmer.stem_word(word) for word in filtered_words]
        return stemmed_words

    def _tokenize(self, text: str) -> List[Token]:
        # Older versions of the sentencepiece Python wrapper return bytes from
        # ``EncodeAsPieces``; newer versions already return ``str``, in which case the decode
        # step below is unnecessary.
        str_tokens = [tok.decode('utf-8') for tok in self._model.EncodeAsPieces(text)]
        tokens = [Token(token, i) for i, token in enumerate(str_tokens)]
        for start_token in self._start_tokens:
            tokens.insert(0, Token(start_token, 0))
        for end_token in self._end_tokens:
            tokens.append(Token(end_token, -1))
        return tokens

    @classmethod
    def from_params(cls, params: Params) -> 'SentencePieceTokenizer':
        model_path = params.pop('model_path')
        word_splitter = WordSplitter.from_params(params.pop('word_splitter', {}))
        word_filter = WordFilter.from_params(params.pop('word_filter', {}))
        word_stemmer = WordStemmer.from_params(params.pop('word_stemmer', {}))
        start_tokens = params.pop('start_tokens', None)
        end_tokens = params.pop('end_tokens', None)
        vocab_size = params.pop('vocab_size', 30000)
        model_type = params.pop('model_type', 'unigram')
        params.assert_empty(cls.__name__)
        return cls(model_path=model_path,
                   word_splitter=word_splitter,
                   word_filter=word_filter,
                   word_stemmer=word_stemmer,
                   start_tokens=start_tokens,
                   end_tokens=end_tokens,
                   vocab_size=vocab_size,
                   model_type=model_type)
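
For reference, here is a minimal sketch of how the tokenizer above could be used on its own, outside a dataset reader. The model path, vocabulary size and example sentences are placeholders; my_library.tokenizers is the module path assumed by the dataset reader below.

from my_library.tokenizers import SentencePieceTokenizer

# Placeholder path and size. --hard_vocab_limit=false (set in `train`) lets SentencePiece
# shrink the vocabulary when the corpus is too small for the requested size.
tokenizer = SentencePieceTokenizer(model_path='spm_unigram.model', vocab_size=1000)
if not tokenizer.trained:
    tokenizer.train([
        "the quick brown fox jumps over the lazy dog",
        "a second example sentence for the subword model",
    ])
tokens = tokenizer.tokenize("the quick brown fox")
print([token.text for token in tokens])

The second file in this gist, below, is a dataset reader that builds on this tokenizer and trains it on the fly while reading a parallel corpus.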

from typing import Dict
import logging

from overrides import overrides

from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer

from my_library.tokenizers import SentencePieceTokenizer

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@DatasetReader.register("seq2seqnew")
class Seq2SeqNewDatasetReader(DatasetReader):
""" | |
Read a tsv file containing paired sequences, and create a dataset suitable for a | |
``SimpleSeq2Seq`` model, or any model with a matching API. | |
Expected format for each input line: <source_sequence_string>\t<target_sequence_string> | |
The output of ``read`` is a list of ``Instance`` s with the fields: | |
source_tokens: ``TextField`` and | |
target_tokens: ``TextField`` | |
`START_SYMBOL` and `END_SYMBOL` tokens are added to the source and target sequences. | |
Parameters | |
---------- | |
source_tokenizer : ``Tokenizer``, optional | |
Tokenizer to use to split the input sequences into words or other kinds of tokens. Defaults | |
to ``WordTokenizer()``. | |
target_tokenizer : ``Tokenizer``, optional | |
Tokenizer to use to split the output sequences (during training) into words or other kinds | |
of tokens. Defaults to ``source_tokenizer``. | |
source_token_indexers : ``Dict[str, TokenIndexer]``, optional | |
Indexers used to define input (source side) token representations. Defaults to | |
``{"tokens": SingleIdTokenIndexer()}``. | |
target_token_indexers : ``Dict[str, TokenIndexer]``, optional | |
Indexers used to define output (target side) token representations. Defaults to | |
``source_token_indexers``. | |
source_add_start_token : bool, (optional, default=True) | |
Whether or not to add `START_SYMBOL` to the beginning of the source sequence. | |
""" | |

    def __init__(self,
                 source_tokenizer: Tokenizer = None,
                 target_tokenizer: Tokenizer = None,
                 source_token_indexers: Dict[str, TokenIndexer] = None,
                 target_token_indexers: Dict[str, TokenIndexer] = None,
                 source_add_start_token: bool = True,
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self._source_tokenizer = source_tokenizer or WordTokenizer()
        self._target_tokenizer = target_tokenizer or self._source_tokenizer
        self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}
        self._target_token_indexers = target_token_indexers or self._source_token_indexers
        self._source_add_start_token = source_add_start_token

    @overrides
    def _read(self, file_path):
        # If either SentencePiece tokenizer has no trained model yet, we have to buffer the
        # whole file, train the tokenizer(s) on it, and only then produce instances.
        source_trained, target_trained = True, True
        if isinstance(self._source_tokenizer, SentencePieceTokenizer):
            source_trained = self._source_tokenizer.trained
        if isinstance(self._target_tokenizer, SentencePieceTokenizer):
            target_trained = self._target_tokenizer.trained
        any_untrained = not (source_trained and target_trained)
        if self.lazy and any_untrained:
            raise ConfigurationError(
                "Cannot read lazily unless both the source and target tokenizers are already trained")
        if any_untrained:
            sources, targets = [], []
        with open(cached_path(file_path), "r") as data_file:
            logger.info("Reading instances from lines in file at: %s", file_path)
            for line_num, line in enumerate(data_file):
                line = line.strip("\n")
                if not line:
                    continue
                line_parts = line.split('\t')
                if len(line_parts) != 2:
                    raise ConfigurationError("Invalid line format: %s (line number %d)" % (line, line_num + 1))
                source_sequence, target_sequence = line_parts
                if any_untrained:
                    sources.append(source_sequence)
                    targets.append(target_sequence)
                else:
                    yield self.text_to_instance(source_sequence, target_sequence)
        if any_untrained:
            if self._source_tokenizer == self._target_tokenizer:
                # A shared tokenizer is trained on both sides of the parallel data.
                self._source_tokenizer.train(sources + targets)
            else:
                self._source_tokenizer.train(sources)
                self._target_tokenizer.train(targets)
            for source_sequence, target_sequence in zip(sources, targets):
                yield self.text_to_instance(source_sequence, target_sequence)

    @overrides
    def text_to_instance(self, source_string: str, target_string: str = None) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        tokenized_source = self._source_tokenizer.tokenize(source_string)
        if self._source_add_start_token:
            tokenized_source.insert(0, Token(START_SYMBOL))
        tokenized_source.append(Token(END_SYMBOL))
        source_field = TextField(tokenized_source, self._source_token_indexers)
        if target_string is not None:
            tokenized_target = self._target_tokenizer.tokenize(target_string)
            tokenized_target.insert(0, Token(START_SYMBOL))
            tokenized_target.append(Token(END_SYMBOL))
            target_field = TextField(tokenized_target, self._target_token_indexers)
            return Instance({"source_tokens": source_field, "target_tokens": target_field})
        else:
            return Instance({"source_tokens": source_field})

    @classmethod
    def from_params(cls, params: Params) -> 'Seq2SeqNewDatasetReader':
        source_tokenizer_type = params.pop('source_tokenizer', None)
        source_tokenizer = None if source_tokenizer_type is None else Tokenizer.from_params(source_tokenizer_type)
        target_tokenizer_type = params.pop('target_tokenizer', None)
        target_tokenizer = None if target_tokenizer_type is None else Tokenizer.from_params(target_tokenizer_type)
        source_indexers_type = params.pop('source_token_indexers', None)
        source_add_start_token = params.pop_bool('source_add_start_token', True)
        if source_indexers_type is None:
            source_token_indexers = None
        else:
            source_token_indexers = TokenIndexer.dict_from_params(source_indexers_type)
        target_indexers_type = params.pop('target_token_indexers', None)
        if target_indexers_type is None:
            target_token_indexers = None
        else:
            target_token_indexers = TokenIndexer.dict_from_params(target_indexers_type)
        lazy = params.pop('lazy', False)
        params.assert_empty(cls.__name__)
        return Seq2SeqNewDatasetReader(source_tokenizer=source_tokenizer,
                                       target_tokenizer=target_tokenizer,
                                       source_token_indexers=source_token_indexers,
                                       target_token_indexers=target_token_indexers,
                                       source_add_start_token=source_add_start_token,
                                       lazy=lazy)
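
As a usage sketch, the reader can also be constructed through from_params, which is how an AllenNLP experiment configuration would reach it. Everything below is illustrative: the file names, the vocabulary size and the my_library.dataset_readers module path are assumptions. Because no target_tokenizer is given, a single SentencePiece tokenizer is shared and therefore trained on both sides of the TSV.

from allennlp.common import Params
from my_library.dataset_readers import Seq2SeqNewDatasetReader  # hypothetical module path

reader = Seq2SeqNewDatasetReader.from_params(Params({
    "source_tokenizer": {
        "type": "sentencepiece",
        "model_path": "subword_shared.model",  # trained on the first read() if the file is absent
        "vocab_size": 8000,
    },
    # target_tokenizer is omitted, so source and target share one tokenizer.
    "lazy": False,
}))
instances = reader.read("train.tsv")  # tab-separated <source>\t<target> pairs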