@mayhewsw
Created July 15, 2019 19:20
Flair in AllenNLP (quick and dirty)

from typing import Dict, List, Tuple, Any
from flair.data import Dictionary
from flair.embeddings import FlairEmbeddings
from overrides import overrides
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import pad_sequence_to_length
from allennlp.data.tokenizers.token import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
# Not happy about this, but apparently it needs to be done.
import logging.config
logging.config.dictConfig({
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'standard': {
            'format': '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
        },
    },
    'handlers': {
        'console': {
            'level': 'INFO',
            'class': 'logging.StreamHandler',
            'formatter': 'standard',
            'stream': 'ext://sys.stdout'
        },
    },
    'loggers': {
        'flair': {
            'handlers': ['console'],
            'level': 'INFO',
            'propagate': False
        }
    },
    'root': {
        'handlers': ['console'],
        'level': 'DEBUG'
    }
})

@TokenIndexer.register("flair")
class FlairCharIndexer(TokenIndexer[List[int]]):
"""
This :class:`TokenIndexer` represents tokens as lists of character indices.
Parameters
----------
pretrained_model: ``str``
The name of the Flair embeddings that you plan to use ("news-forward", for example). These are only
used here to load the character mapping dictionary and to get the is_forward_lm variable.
namespace : ``str``, optional (default=``flair``)
Since Flair uses it's own dictionaries, the vocabulary is actually not used.
character_tokenizer : ``CharacterTokenizer``, optional (default=``CharacterTokenizer()``)
We use a :class:`CharacterTokenizer` to handle splitting tokens into characters, as it has
options for byte encoding and other things. The default here is to instantiate a
``CharacterTokenizer`` with its default parameters, which uses unicode characters and
retains casing.
start_chars : ``List[str]``, optional (default=``None``)
These are prepended to the tokens provided to ``tokens_to_indices``.
end_chars : ``List[str]``, optional (default=``None``)
These are appended to the tokens provided to ``tokens_to_indices``.
min_padding_length: ``int``, optional (default=``0``)
We use this value as the minimum length of padding. Usually used with :class:``CnnEncoder``, its
value should be set to the maximum value of ``ngram_filter_sizes`` correspondingly.
"""
    # pylint: disable=no-self-use
    def __init__(self,
                 pretrained_model: str,
                 namespace: str = 'flair',
                 character_tokenizer: CharacterTokenizer = CharacterTokenizer(),
                 start_chars: List[str] = None,
                 end_chars: List[str] = None,
                 min_padding_length: int = 0) -> None:
        self._min_padding_length = min_padding_length
        self._namespace = namespace
        self._character_tokenizer = character_tokenizer

        # By default, start should be "\n" and end should be " " (according to official flair code).
        self._start_chars = [Token(st) for st in (start_chars or ["\n"])]
        self._end_chars = [Token(et) for et in (end_chars or [" "])]

        # Really we are just loading this for the character dictionary and the is_forward_lm variable;
        # this should match the model used in the embedder.
        flair_embs = FlairEmbeddings(pretrained_model)
        self.dictionary: Dictionary = flair_embs.lm.dictionary
        self.is_forward_lm = flair_embs.is_forward_lm

        # Fixed for now (to match Flair), but could technically be anything...
        self.separator = " "

    @overrides
    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        if token.text is None:
            raise ConfigurationError('FlairCharIndexer needs a tokenizer that retains text')
        for character in self._character_tokenizer.tokenize(token.text):
            counter[self._namespace][character.text] += 1

    @overrides
    def tokens_to_indices(self,
                          tokens: List[Token],
                          vocabulary: Vocabulary,
                          index_name: str) -> Dict[str, List[List[int]]]:
        indices: List[int] = []

        # We don't want to modify tokens for everybody else.
        mytoks = list(tokens)

        span_start = 0
        for c in self._start_chars:
            indices.append(self.dictionary.get_idx_for_item(c.text))
            span_start += len(c.text)

        spans: List[Tuple[int, int]] = []

        if not self.is_forward_lm:
            # Reverse the token order here; the characters within each word are reversed
            # separately below.
            mytoks = mytoks[::-1]

        for i, token in enumerate(mytoks):
            if token.text is None:
                raise ConfigurationError('FlairCharIndexer needs a tokenizer that retains text')

            chars = self._character_tokenizer.tokenize(token.text)
            if not self.is_forward_lm:
                chars = chars[::-1]

            for character in chars:
                index = self.dictionary.get_idx_for_item(character.text)
                indices.append(index)

            # This prevents there being a separator between the last token and the end characters.
            if i < len(mytoks) - 1:
                for c in self.separator:
                    indices.append(self.dictionary.get_idx_for_item(c))

            # NOTICE: flair offsets are weird. The beginning offset uses one character *before* the token
            # and the end offset uses one character *after* the token.
            # In practice though, we never actually use the first element of the span because
            # flair uses two separate unidirectional LMs.
            spans.append((span_start - 1, span_start + len(token.text)))
            span_start += len(token.text) + len(self.separator)

        if not self.is_forward_lm:
            # When using spans to select tokens from the indices,
            # reversing the spans means that we will select indices
            # the right way around. O_o
            spans = spans[::-1]

        for c in self._end_chars:
            indices.append(self.dictionary.get_idx_for_item(c.text))

        return {index_name: indices, index_name + "_spans": spans}
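
    # Worked example (a sketch, assuming a forward LM and the default start/end/separator
    # characters): for tokens ["hi", "yo"] the character stream that gets indexed is
    #   "\n" "h" "i" " " "y" "o" " "
    #    0    1   2   3   4   5   6
    # and the spans are [(0, 3), (3, 6)]: each span starts one character *before* its token
    # and ends on the character *after* it, matching the NOTICE above.
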
    @overrides
    def get_padding_lengths(self, token: List[int]) -> Dict[str, int]:
        return {}

    @overrides
    def get_padding_token(self) -> List[int]:
        # Following flair code, we use a space as the default padding token.
        return self.dictionary.get_idx_for_item(" ")

    @overrides
    def pad_token_sequence(self,
                           tokens: Dict[str, List[Any]],
                           desired_num_tokens: Dict[str, int],
                           padding_lengths: Dict[str, int]) -> Dict[str, List[int]]:
        out = {}
        for key, val in tokens.items():
            if "spans" in key:
                # Spans are padded with tuples instead of integers.
                out[key] = pad_sequence_to_length(val, desired_num_tokens[key], default_value=lambda: (0, 0))
            else:
                # It is EXTREMELY important that the padding be a space, as this is how the CLM was trained.
                out[key] = pad_sequence_to_length(val, desired_num_tokens[key],
                                                  default_value=lambda: self.dictionary.get_idx_for_item(" "))

        # When creating a mask, AllenNLP doesn't know the difference between a tensor of *characters*
        # and a tensor of *tokens*. To get around this, we do the same trick that was used in the
        # Bert indexer and inject a mask here, since the span field holds the token information.
        tok_key = list(filter(lambda s: "spans" in s, tokens.keys()))[0]
        num_toks = len(tokens[tok_key])
        desired_toks = desired_num_tokens[tok_key]
        mask = [1] * num_toks + [0] * (desired_toks - num_toks)
        out["mask"] = mask
        return out
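

# A minimal smoke test for the indexer (a sketch; it assumes the "news-forward" Flair model can
# be downloaded, and the index name "flair" is an arbitrary choice for illustration).
if __name__ == "__main__":
    indexer = FlairCharIndexer("news-forward")
    toks = [Token("Hello"), Token("world")]
    indexed = indexer.tokens_to_indices(toks, Vocabulary(), "flair")
    # "flair" holds one character id per character (start char, token characters, separators,
    # end char); "flair_spans" holds one (start, end) character span per token.
    print(len(indexed["flair"]), indexed["flair_spans"])
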
"""
Flair Embedder.
"""
import logging
import torch
from allennlp.modules.span_extractors import EndpointSpanExtractor
from allennlp.nn.util import get_text_field_mask
from flair.embeddings import FlairEmbeddings
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
logger = logging.getLogger(__name__)
class FlairEmbedder(TokenEmbedder):
    def __init__(self, flair_model: FlairEmbeddings) -> None:
        super().__init__()
        self.flair_model = flair_model
        self.pretrain_name = self.flair_model.name
        self.output_dim = flair_model.lm.hidden_size
        for param in self.flair_model.lm.parameters():
            param.requires_grad = False

        # In Flair, every LM is unidirectional going forwards (the "backward" models see
        # reversed text), so we always extract the representation at the right edge of a span.
        comb_string = "y"
        self.span_extractor = EndpointSpanExtractor(input_dim=self.flair_model.lm.hidden_size, combination=comb_string)

        # Set the model to None so it is reloaded in forward(). Something modifies the model
        # somewhere between __init__ and forward(), and we have not tracked down what;
        # reloading it in forward() works around the problem.
        self.flair_model = None

    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self, input_ids: torch.LongTensor, spans: torch.LongTensor) -> torch.Tensor:
        # Super hack: the model is changed between __init__ and forward() and we don't know how,
        # so reload it lazily here.
        if self.flair_model is None:
            self.flair_model = FlairEmbeddings(self.pretrain_name)
            for param in self.flair_model.lm.parameters():
                param.requires_grad = False

        with torch.no_grad():
            # It doesn't matter what this key is; it just needs to be a dict.
            mask = get_text_field_mask({"chars": input_ids})

            # Trick: fake (padding) spans have a sum that is 0 or less.
            # See FlairCharIndexer.pad_token_sequence() for the default tuple value.
            mask_spans = (spans.sum(dim=2) > 0).long()

            # Shape: (max_char_seq_len, batch_size)
            batch_second_input_ids = input_ids.transpose(0, 1)
            max_seq_len, batch_size = batch_second_input_ids.shape

            hidden = self.flair_model.lm.init_hidden(batch_size)
            prediction, batch_second_rnn_output, _ = self.flair_model.lm.forward(batch_second_input_ids, hidden)

            # Shape: (batch_size, max_char_seq_len, hidden_size)
            rnn_output = batch_second_rnn_output.transpose(1, 0)

            word_embeddings = self.span_extractor(rnn_output.contiguous(), spans, mask, mask_spans).contiguous()

        return word_embeddings
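
# Shape sketch for forward() (a hedged example, not from the gist): with a batch of 2 sentences
# padded to 30 characters and 6 tokens, and an LM hidden size H (flair_model.lm.hidden_size),
# input_ids is (2, 30), spans is (2, 6, 2), the transposed RNN output is (2, 30, H), and the
# "y" (endpoint-only) span extractor returns word embeddings of shape (2, 6, H). This is why
# get_output_dim() simply returns the LM hidden size.
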
@TokenEmbedder.register("flair-pretrained")
class PretrainedFlairEmbedder(FlairEmbedder):
    # pylint: disable=line-too-long
    """
    Parameters
    ----------
    pretrained_model : ``str``
        The name of the pretrained model to use (e.g. 'news-forward').
        If the name is a key in the list of pretrained models at
        https://github.com/zalandoresearch/flair/blob/master/flair/embeddings.py#L834
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    """

    def __init__(self, pretrained_model: str) -> None:
        flair_embs = FlairEmbeddings(pretrained_model)
        super().__init__(flair_model=flair_embs)
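
# A hedged sketch of the embedder side of an AllenNLP jsonnet config; everything except the
# "flair-pretrained" type name and the pretrained_model argument is a placeholder. The model
# name should match the one given to FlairCharIndexer so the character dictionary and the LM
# agree; how the indexer's output keys get mapped onto forward(input_ids, spans) depends on
# text-field-embedder wiring that is not shown in this gist.
#
#   "text_field_embedder": {
#     "token_embedders": {
#       "flair": {
#         "type": "flair-pretrained",
#         "pretrained_model": "news-forward"
#       }
#     }
#   }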