Flair in Allennlp (quick and dirty)
from typing import Dict, List, Tuple, Any

from flair.data import Dictionary
from flair.embeddings import FlairEmbeddings
from overrides import overrides

from allennlp.common.checks import ConfigurationError
from allennlp.common.util import pad_sequence_to_length
from allennlp.data.tokenizers.token import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer

# Not happy about this, but apparently it needs to be done.
import logging.config
logging.config.dictConfig({
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'standard': {
            'format': '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
        },
    },
    'handlers': {
        'console': {
            'level': 'INFO',
            'class': 'logging.StreamHandler',
            'formatter': 'standard',
            'stream': 'ext://sys.stdout'
        },
    },
    'loggers': {
        'flair': {
            'handlers': ['console'],
            'level': 'INFO',
            'propagate': False
        }
    },
    'root': {
        'handlers': ['console'],
        'level': 'DEBUG'
    }
})
@TokenIndexer.register("flair") | |
class FlairCharIndexer(TokenIndexer[List[int]]): | |
""" | |
This :class:`TokenIndexer` represents tokens as lists of character indices. | |
Parameters | |
---------- | |
pretrained_model: ``str`` | |
The name of the Flair embeddings that you plan to use ("news-forward", for example). These are only | |
used here to load the character mapping dictionary and to get the is_forward_lm variable. | |
namespace : ``str``, optional (default=``flair``) | |
Since Flair uses it's own dictionaries, the vocabulary is actually not used. | |
character_tokenizer : ``CharacterTokenizer``, optional (default=``CharacterTokenizer()``) | |
We use a :class:`CharacterTokenizer` to handle splitting tokens into characters, as it has | |
options for byte encoding and other things. The default here is to instantiate a | |
``CharacterTokenizer`` with its default parameters, which uses unicode characters and | |
retains casing. | |
start_chars : ``List[str]``, optional (default=``None``) | |
These are prepended to the tokens provided to ``tokens_to_indices``. | |
end_chars : ``List[str]``, optional (default=``None``) | |
These are appended to the tokens provided to ``tokens_to_indices``. | |
min_padding_length: ``int``, optional (default=``0``) | |
We use this value as the minimum length of padding. Usually used with :class:``CnnEncoder``, its | |
value should be set to the maximum value of ``ngram_filter_sizes`` correspondingly. | |
""" | |
# pylint: disable=no-self-use | |
def __init__(self, | |
pretrained_model: str, | |
namespace: str = 'flair', | |
character_tokenizer: CharacterTokenizer = CharacterTokenizer(), | |
start_chars: List[str] = None, | |
end_chars: List[str] = None, | |
min_padding_length: int = 0) -> None: | |
self._min_padding_length = min_padding_length | |
self._namespace = namespace | |
self._character_tokenizer = character_tokenizer | |
# by default, start should be "\n" and end should be " " (according to official flair code) | |
self._start_chars = [Token(st) for st in (start_chars or ["\n"])] | |
self._end_chars = [Token(et) for et in (end_chars or [" "])] | |
# really we are just loading this for the character dictionary and the is_forward_lm variable | |
# this should match the model used in the embedder. | |
flair_embs = FlairEmbeddings(pretrained_model) | |
self.dictionary: Dictionary = flair_embs.lm.dictionary | |
self.is_forward_lm = flair_embs.is_forward_lm | |
# Fixed for now (to match Flair), but could technically be anything... | |
self.separator = " " | |
    @overrides
    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        if token.text is None:
            raise ConfigurationError('FlairCharIndexer needs a tokenizer that retains text')
        for character in self._character_tokenizer.tokenize(token.text):
            counter[self._namespace][character.text] += 1

    @overrides
    def tokens_to_indices(self,
                          tokens: List[Token],
                          vocabulary: Vocabulary,
                          index_name: str) -> Dict[str, List[List[int]]]:
        indices: List[int] = []

        # We don't want to modify `tokens` for everybody else.
        mytoks = list(tokens)

        span_start = 0
        for c in self._start_chars:
            indices.append(self.dictionary.get_idx_for_item(c.text))
            span_start += len(c.text)

        spans: List[Tuple[int, int]] = []

        if not self.is_forward_lm:
            # Reverse the token order; the characters within each word are reversed separately below.
            mytoks = mytoks[::-1]

        for i, token in enumerate(mytoks):
            if token.text is None:
                raise ConfigurationError('FlairCharIndexer needs a tokenizer that retains text')

            chars = self._character_tokenizer.tokenize(token.text)
            if not self.is_forward_lm:
                chars = chars[::-1]

            for character in chars:
                index = self.dictionary.get_idx_for_item(character.text)
                indices.append(index)

            # This prevents there being a separator between the last token and the end characters.
            if i < len(mytoks) - 1:
                for c in self.separator:
                    indices.append(self.dictionary.get_idx_for_item(c))

            # NOTICE: Flair offsets are unusual. The beginning offset uses one character *before* the token
            # and the end offset uses one character *after* the token.
            # In practice, though, we never actually use the first element of the span because
            # Flair uses two separate unidirectional LMs.
            spans.append((span_start - 1, span_start + len(token.text)))
            span_start += len(token.text) + len(self.separator)

        if not self.is_forward_lm:
            # When using spans to select tokens from the indices,
            # reversing the spans means that we will select indices
            # the right way around. O_o
            spans = spans[::-1]

        for c in self._end_chars:
            indices.append(self.dictionary.get_idx_for_item(c.text))

        return {index_name: indices, index_name + "_spans": spans}
    @overrides
    def get_padding_lengths(self, token: List[int]) -> Dict[str, int]:
        return {}

    @overrides
    def get_padding_token(self) -> List[int]:
        # Following the Flair code, we use a space as the default padding token.
        return self.dictionary.get_idx_for_item(" ")

    @overrides
    def pad_token_sequence(self,
                           tokens: Dict[str, List[Any]],
                           desired_num_tokens: Dict[str, int],
                           padding_lengths: Dict[str, int]) -> Dict[str, List[int]]:
        out = {}
        for key, val in tokens.items():
            if "spans" in key:
                # Spans are padded with tuples instead of integers.
                out[key] = pad_sequence_to_length(val, desired_num_tokens[key], default_value=lambda: (0, 0))
            else:
                # It is EXTREMELY important that the padding be a space, as this is how the CLM was trained.
                out[key] = pad_sequence_to_length(val, desired_num_tokens[key],
                                                  default_value=lambda: self.dictionary.get_idx_for_item(" "))

        # When creating a mask, AllenNLP doesn't know the difference between a tensor of *characters*
        # and a tensor of *tokens*. To get around this, we do the same trick that was used in the
        # BERT indexer and inject a mask here, since the span field holds the token information.
        tok_key = list(filter(lambda s: "spans" in s, tokens.keys()))[0]
        num_toks = len(tokens[tok_key])
        desired_toks = desired_num_tokens[tok_key]
        mask = [1] * num_toks + [0] * (desired_toks - num_toks)
        out["mask"] = mask
        return out
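For reference, here is a minimal usage sketch (not part of the gist) of what the indexer above produces for a short sentence. It assumes the AllenNLP 0.x API used throughout this file (the `tokens_to_indices(tokens, vocabulary, index_name)` signature) and that the "news-forward" Flair model can be downloaded on this machine; the index name "flair" is just an illustrative choice.

# Sketch only: inspect the indexer output for a tiny sentence.
from allennlp.data.tokenizers.token import Token
from allennlp.data.vocabulary import Vocabulary

indexer = FlairCharIndexer("news-forward")          # downloads/loads the Flair LM
tokens = [Token("The"), Token("dog"), Token("barked")]

# The AllenNLP vocabulary is unused (Flair brings its own character dictionary),
# so an empty one is fine here.
output = indexer.tokens_to_indices(tokens, Vocabulary(), "flair")

# "flair" holds one character id per character of the sentence (plus start/end
# characters and separators); "flair_spans" holds one (start, end) character span
# per token, which the embedder's span extractor consumes later.
print(len(output["flair"]))     # total number of character positions
print(output["flair_spans"])    # three (start, end) tuples, one per token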
""" | |
Flair Embedder. | |
""" | |
import logging | |
import torch | |
from allennlp.modules.span_extractors import EndpointSpanExtractor | |
from allennlp.nn.util import get_text_field_mask | |
from flair.embeddings import FlairEmbeddings | |
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder | |
logger = logging.getLogger(__name__) | |
class FlairEmbedder(TokenEmbedder): | |
def __init__(self, flair_model: FlairEmbeddings) -> None: | |
super().__init__() | |
self.flair_model = flair_model | |
self.pretrain_name = self.flair_model.name | |
self.output_dim = flair_model.lm.hidden_size | |
for param in self.flair_model.lm.parameters(): | |
param.requires_grad = False | |
# In Flair, every LM is unidirectional going forwards. | |
# We always extract on the right side. | |
comb_string = "y" | |
self.span_extractor = EndpointSpanExtractor(input_dim=self.flair_model.lm.hidden_size, combination=comb_string) | |
# Set model to None so it is reloaded in the forward. | |
# Have no idea what is happening, but something is modifying the model | |
# somewhere between init and forward. Reloading in forward works. | |
self.flair_model = None | |
def get_output_dim(self) -> int: | |
return self.output_dim | |
def forward(self, input_ids: torch.LongTensor, spans: torch.LongTensor) -> torch.Tensor: | |
# Super hack. Model is changed between init and forward and we don't know how. | |
if self.flair_model is None: | |
self.flair_model = FlairEmbeddings(self.pretrain_name) | |
for param in self.flair_model.lm.parameters(): | |
param.requires_grad = False | |
with torch.no_grad(): | |
# Doesn't matter what this key is, just needs to be a dict. | |
mask = get_text_field_mask({"chars": input_ids}) | |
# trick: fake spans have a sum that is 0 or less. | |
# Look at FlairCharIndexer.pad_token_sequence() to see what the default tuple value is | |
mask_spans = (spans.sum(dim=2) > 0).long() | |
# Shape: (max_char_seq_len, batch_size) | |
batch_second_input_ids = input_ids.transpose(0, 1) | |
max_seq_len, batch_size = batch_second_input_ids.shape | |
hidden = self.flair_model.lm.init_hidden(batch_size) | |
prediction, batch_second_rnn_output, _ = self.flair_model.lm.forward(batch_second_input_ids, hidden) | |
# Shape: (batch_size, max_char_seq_len) | |
rnn_output = batch_second_rnn_output.transpose(1, 0) | |
word_embeddings = self.span_extractor(rnn_output.contiguous(), spans, mask, mask_spans).contiguous() | |
return word_embeddings | |
@TokenEmbedder.register("flair-pretrained")
class PretrainedFlairEmbedder(FlairEmbedder):
    # pylint: disable=line-too-long
    """
    Parameters
    ----------
    pretrained_model : ``str``
        The name of the pretrained model to use (e.g. 'news-forward'). If the name is a key in the list
        of pretrained models at
        https://github.com/zalandoresearch/flair/blob/master/flair/embeddings.py#L834
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    """
    def __init__(self, pretrained_model: str) -> None:
        flair_embs = FlairEmbeddings(pretrained_model)
        super().__init__(flair_model=flair_embs)
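To tie the two files together, the indexer and embedder have to be wired up in an AllenNLP experiment config. The fragment below is only a sketch of how that might look, shown as a Python dict rather than the Jsonnet/JSON file it would normally live in. The indexer key "flair", the model "news-forward", and the use of allow_unmatched_keys with embedder_to_indexer_map (the routing pattern AllenNLP 0.x uses for BERT, needed because the indexer emits "flair", "flair_spans", and "mask") are assumptions, not something specified by the gist, and should be checked against the AllenNLP version in use.

# Hypothetical config fragment; keys and model names are illustrative.
config_fragment = {
    "dataset_reader": {
        "token_indexers": {
            "flair": {
                "type": "flair",                     # registered above via @TokenIndexer.register("flair")
                "pretrained_model": "news-forward"   # must match the embedder's model
            }
        }
    },
    "model": {
        "text_field_embedder": {
            # Route the indexer's "flair" and "flair_spans" outputs to forward(input_ids, spans),
            # and tolerate the extra injected "mask" key.
            "allow_unmatched_keys": True,
            "embedder_to_indexer_map": {"flair": ["flair", "flair_spans"]},
            "token_embedders": {
                "flair": {
                    "type": "flair-pretrained",      # registered above via @TokenEmbedder.register("flair-pretrained")
                    "pretrained_model": "news-forward"
                }
            }
        }
    }
}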