SentencePiece AllenNLP
@sai-prasanna, created August 4, 2018 15:05
import os
import tempfile
import logging
import sentencepiece as spm
from typing import List
from overrides import overrides
from allennlp.common import Params
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data.tokenizers.word_filter import WordFilter, PassThroughWordFilter
from allennlp.data.tokenizers.word_splitter import WordSplitter, JustSpacesWordSplitter
from allennlp.data.tokenizers.word_stemmer import WordStemmer, PassThroughWordStemmer
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@Tokenizer.register("sentencepiece")
class SentencePieceTokenizer(Tokenizer):
"""
A ``WordTokenizer`` handles the splitting of strings into words as well as any desired
post-processing (e.g., stemming, filtering, etc.). Note that we leave one particular piece of
post-processing for later: the decision of whether or not to lowercase the token. This is for
two reasons: (1) if you want to make two different casing decisions for whatever reason, you
won't have to run the tokenizer twice, and more importantly (2) if you want to lowercase words
for your word embedding, but retain capitalization in a character-level representation, we need
to retain the capitalization here.
Parameters
----------
word_splitter : ``WordSplitter``, optional
The :class:`WordSplitter` to use for splitting text strings into word tokens. The default
is to use the ``JustSpacesWordSplitter``, which is non destructive other than for spaces.
word_filter : ``WordFilter``, optional
The :class:`WordFilter` to use for, e.g., removing stopwords. Default is to do no
filtering.
word_stemmer : ``WordStemmer``, optional
The :class:`WordStemmer` to use. Default is no stemming.
start_tokens : ``List[str]``, optional
If given, these tokens will be added to the beginning of every string we tokenize.
end_tokens : ``List[str]``, optional
If given, these tokens will be added to the end of every string we tokenize.
"""
    def __init__(self,
                 model_path: str,
                 word_splitter: WordSplitter = None,
                 word_filter: WordFilter = PassThroughWordFilter(),
                 word_stemmer: WordStemmer = PassThroughWordStemmer(),
                 start_tokens: List[str] = None,
                 end_tokens: List[str] = None,
                 vocab_size: int = 30000,
                 model_type: str = 'unigram') -> None:
        self._word_splitter = word_splitter or JustSpacesWordSplitter()
        self._word_filter = word_filter
        self._word_stemmer = word_stemmer
        self._start_tokens = start_tokens or []
        # We reverse the tokens here because we're going to insert them with `insert(0)` later;
        # this makes sure they show up in the right order.
        self._start_tokens.reverse()
        self._end_tokens = end_tokens or []
        self._vocab_size = vocab_size
        self._model_path = model_path
        self._model = spm.SentencePieceProcessor()
        self._model_type = model_type
        self.trained = False
        if os.path.exists(self._model_path):
            self._model.Load(self._model_path)
            self.trained = True
    def train(self, texts: List[str]) -> None:
        """
        Train the tokenizer's subword model on ``texts`` and save it to ``model_path``.
        """
        if self.trained:
            logger.warning("Tokenizer model already exists, skipping training.")
            return
        preprocessed_texts = self._batch_preprocess(texts)
        # SentencePiece trains from a file with one sentence per line, so write the
        # preprocessed sentences to a temporary file and close it before training.
        with tempfile.NamedTemporaryFile('w', delete=False) as temp_fp:
            temp_fp.write("\n".join(preprocessed_texts) + "\n")
        model_prefix = 'subword'
        spm.SentencePieceTrainer.Train(' --hard_vocab_limit=false'
                                       f' --input={temp_fp.name}'
                                       f' --model_prefix={model_prefix}'
                                       f' --vocab_size={self._vocab_size}'
                                       f' --model_type={self._model_type}')
        os.remove(temp_fp.name)
        os.rename(f'{model_prefix}.model', self._model_path)
        os.remove(f'{model_prefix}.vocab')  # Discard the .vocab file; we use AllenNLP's Vocabulary instead.
        self._model.Load(self._model_path)
        self.trained = True
    @overrides
    def tokenize(self, text: str) -> List[Token]:
        """
        Does whatever processing is required to convert a string of text into a sequence of
        tokens.  At a minimum, this uses the ``WordSplitter`` to split the text into words,
        optionally filters and stems them, and then encodes the result into subword pieces
        with the SentencePiece model.
        """
        words = self._word_splitter.split_words(text)
        words = self._filter_and_stem(words)
        preprocessed_text = " ".join(map(str, words))
        return self._tokenize(preprocessed_text)
    @overrides
    def batch_tokenize(self, texts: List[str]) -> List[List[Token]]:
        preprocessed_texts = self._batch_preprocess(texts)
        return [self._tokenize(text) for text in preprocessed_texts]
    def _batch_preprocess(self, texts: List[str]) -> List[str]:
        """
        Does word splitting, filtering, and stemming on a batch of sentences and returns the
        preprocessed sentences with their words joined by spaces.
        """
        batched_words = self._word_splitter.batch_split_words(texts)
        batched_words = [self._filter_and_stem(words) for words in batched_words]
        return [" ".join(map(str, words)) for words in batched_words]
    def _filter_and_stem(self, words: List[Token]) -> List[Token]:
        filtered_words = self._word_filter.filter_words(words)
        stemmed_words = [self._word_stemmer.stem_word(word) for word in filtered_words]
        return stemmed_words
    def _tokenize(self, text: str) -> List[Token]:
        # EncodeAsPieces returned bytes in 2018-era sentencepiece releases; newer versions
        # return str, in which case the decode call below is unnecessary.
        str_tokens = [tok.decode('utf-8') for tok in self._model.EncodeAsPieces(text)]
        tokens = [Token(token, i) for i, token in enumerate(str_tokens)]
        for start_token in self._start_tokens:
            tokens.insert(0, Token(start_token, 0))
        for end_token in self._end_tokens:
            tokens.append(Token(end_token, -1))
        return tokens
    @classmethod
    def from_params(cls, params: Params) -> 'SentencePieceTokenizer':
        model_path = params.pop('model_path', None)
        word_splitter = WordSplitter.from_params(params.pop('word_splitter', {}))
        word_filter = WordFilter.from_params(params.pop('word_filter', {}))
        word_stemmer = WordStemmer.from_params(params.pop('word_stemmer', {}))
        start_tokens = params.pop('start_tokens', None)
        end_tokens = params.pop('end_tokens', None)
        vocab_size = params.pop('vocab_size', 30000)
        model_type = params.pop('model_type', 'unigram')
        params.assert_empty(cls.__name__)
        return cls(model_path=model_path,
                   word_splitter=word_splitter,
                   word_filter=word_filter,
                   word_stemmer=word_stemmer,
                   start_tokens=start_tokens,
                   end_tokens=end_tokens,
                   vocab_size=vocab_size,
                   model_type=model_type)
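

# A minimal usage sketch of the tokenizer driven directly from Python.  Assumes an
# AllenNLP 0.x-era environment with the ``sentencepiece`` package installed; the model
# path and example sentences are hypothetical placeholders.
if __name__ == '__main__':
    tokenizer = SentencePieceTokenizer(model_path='/tmp/example_subword.model',
                                       vocab_size=100,
                                       model_type='unigram')
    # ``train`` is a no-op if a model file already existed at ``model_path``.
    tokenizer.train(["this is a tiny example corpus",
                     "sentencepiece learns subword units from raw text"])
    print(tokenizer.tokenize("this is a tiny example"))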
from typing import Dict
import logging
from my_library.tokenizers import SentencePieceTokenizer
from overrides import overrides
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@DatasetReader.register("seq2seqnew")
class Seq2SeqNewDatasetReader(DatasetReader):
"""
Read a tsv file containing paired sequences, and create a dataset suitable for a
``SimpleSeq2Seq`` model, or any model with a matching API.
Expected format for each input line: <source_sequence_string>\t<target_sequence_string>
The output of ``read`` is a list of ``Instance`` s with the fields:
source_tokens: ``TextField`` and
target_tokens: ``TextField``
`START_SYMBOL` and `END_SYMBOL` tokens are added to the source and target sequences.
Parameters
----------
source_tokenizer : ``Tokenizer``, optional
Tokenizer to use to split the input sequences into words or other kinds of tokens. Defaults
to ``WordTokenizer()``.
target_tokenizer : ``Tokenizer``, optional
Tokenizer to use to split the output sequences (during training) into words or other kinds
of tokens. Defaults to ``source_tokenizer``.
source_token_indexers : ``Dict[str, TokenIndexer]``, optional
Indexers used to define input (source side) token representations. Defaults to
``{"tokens": SingleIdTokenIndexer()}``.
target_token_indexers : ``Dict[str, TokenIndexer]``, optional
Indexers used to define output (target side) token representations. Defaults to
``source_token_indexers``.
source_add_start_token : bool, (optional, default=True)
Whether or not to add `START_SYMBOL` to the beginning of the source sequence.
"""
    def __init__(self,
                 source_tokenizer: Tokenizer = None,
                 target_tokenizer: Tokenizer = None,
                 source_token_indexers: Dict[str, TokenIndexer] = None,
                 target_token_indexers: Dict[str, TokenIndexer] = None,
                 source_add_start_token: bool = True,
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self._source_tokenizer = source_tokenizer or WordTokenizer()
        self._target_tokenizer = target_tokenizer or self._source_tokenizer
        self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}
        self._target_token_indexers = target_token_indexers or self._source_token_indexers
        self._source_add_start_token = source_add_start_token
    @overrides
    def _read(self, file_path):
        # Check whether either tokenizer still needs its subword model trained.
        source_trained, target_trained = True, True
        if isinstance(self._source_tokenizer, SentencePieceTokenizer):
            source_trained = self._source_tokenizer.trained
        if isinstance(self._target_tokenizer, SentencePieceTokenizer):
            target_trained = self._target_tokenizer.trained
        any_untrained = not (source_trained and target_trained)
        if self.lazy and any_untrained:
            raise ConfigurationError(
                    "Cannot read lazily before both source_tokenizer and target_tokenizer have been trained")
        if any_untrained:
            sources, targets = [], []
        with open(cached_path(file_path), "r") as data_file:
            logger.info("Reading instances from lines in file at: %s", file_path)
            for line_num, line in enumerate(data_file):
                line = line.strip("\n")
                if not line:
                    continue
                line_parts = line.split('\t')
                if len(line_parts) != 2:
                    raise ConfigurationError("Invalid line format: %s (line number %d)" % (line, line_num + 1))
                source_sequence, target_sequence = line_parts
                if any_untrained:
                    # Buffer the raw sequences so the subword model(s) can be trained first.
                    sources.append(source_sequence)
                    targets.append(target_sequence)
                else:
                    yield self.text_to_instance(source_sequence, target_sequence)
        if any_untrained:
            if self._source_tokenizer is self._target_tokenizer:
                self._source_tokenizer.train(sources + targets)
            else:
                self._source_tokenizer.train(sources)
                self._target_tokenizer.train(targets)
            for source_sequence, target_sequence in zip(sources, targets):
                yield self.text_to_instance(source_sequence, target_sequence)
    @overrides
    def text_to_instance(self, source_string: str, target_string: str = None) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        tokenized_source = self._source_tokenizer.tokenize(source_string)
        if self._source_add_start_token:
            tokenized_source.insert(0, Token(START_SYMBOL))
        tokenized_source.append(Token(END_SYMBOL))
        source_field = TextField(tokenized_source, self._source_token_indexers)
        if target_string is not None:
            tokenized_target = self._target_tokenizer.tokenize(target_string)
            tokenized_target.insert(0, Token(START_SYMBOL))
            tokenized_target.append(Token(END_SYMBOL))
            target_field = TextField(tokenized_target, self._target_token_indexers)
            return Instance({"source_tokens": source_field, "target_tokens": target_field})
        else:
            return Instance({'source_tokens': source_field})
    @classmethod
    def from_params(cls, params: Params) -> 'Seq2SeqNewDatasetReader':
        source_tokenizer_type = params.pop('source_tokenizer', None)
        source_tokenizer = None if source_tokenizer_type is None else Tokenizer.from_params(source_tokenizer_type)
        target_tokenizer_type = params.pop('target_tokenizer', None)
        target_tokenizer = None if target_tokenizer_type is None else Tokenizer.from_params(target_tokenizer_type)
        source_indexers_type = params.pop('source_token_indexers', None)
        source_add_start_token = params.pop_bool('source_add_start_token', True)
        if source_indexers_type is None:
            source_token_indexers = None
        else:
            source_token_indexers = TokenIndexer.dict_from_params(source_indexers_type)
        target_indexers_type = params.pop('target_token_indexers', None)
        if target_indexers_type is None:
            target_token_indexers = None
        else:
            target_token_indexers = TokenIndexer.dict_from_params(target_indexers_type)
        lazy = params.pop('lazy', False)
        params.assert_empty(cls.__name__)
        return Seq2SeqNewDatasetReader(source_tokenizer=source_tokenizer,
                                       target_tokenizer=target_tokenizer,
                                       source_token_indexers=source_token_indexers,
                                       target_token_indexers=target_token_indexers,
                                       source_add_start_token=source_add_start_token,
                                       lazy=lazy)
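

# A minimal sketch of how this reader might be configured, written as the Python dict
# an AllenNLP ``Params`` object is built from.  In a training config the same dict
# would sit under ``"dataset_reader": {"type": "seq2seqnew", ...}``.  The model path
# and vocab size are hypothetical placeholders; only the "seq2seqnew" and
# "sentencepiece" type names come from the registrations above.  Assumes an
# AllenNLP 0.x-era API.
if __name__ == '__main__':
    reader = Seq2SeqNewDatasetReader.from_params(Params({
            "source_tokenizer": {
                    "type": "sentencepiece",
                    "model_path": "/tmp/seq2seq_subword.model",
                    "vocab_size": 8000
            },
            "source_add_start_token": True
    }))
    # With no target_tokenizer given, the source tokenizer is reused for targets, so
    # _read() trains a single subword model on sources + targets the first time a
    # (non-lazy) TSV file is read.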