Best pretraining for Russian language - embedding bag interfaces
import numpy as np
import torch
import torch.nn as nn
from torch.nn import EmbeddingBag

# Assumed imports / helpers: BertLayerNorm comes from the BERT implementation this
# gist builds on (pytorch_pretrained_bert.modeling); upkl and word_ngrams are
# external helpers from the same project (they unpickle the n-gram dict and
# produce fastText-style character n-grams) and are not defined in this file.
from pytorch_pretrained_bert.modeling import BertLayerNorm


class BertEmbeddingBag(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings,
    with the word embedding lookup replaced by a fastText-style embedding bag
    over character n-grams.
    """
    def __init__(self, config):
        super(BertEmbeddingBag, self).__init__()
        # self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        ngram_matrix = np.load(config.ngram_matrix_path)
        self.old_bag = config.old_bag
        if self.old_bag:
            self.embedding_bag = OldFastTextEmbeddingBag(upkl(config.ngram_dict_path),
                                                         ngram_matrix,
                                                         config.device)
        else:
            self.embedding_bag = FastTextEmbeddingBag(ngram_matrix,
                                                      config.device)
        assert ngram_matrix.shape[1] == config.emb_size
        del ngram_matrix
        self.hidden_size = config.hidden_size
        self.emb_size = config.emb_size
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.emb_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.emb_size)
        # projects the n-gram embedding size up to the transformer hidden size when they differ
        self.linear = nn.Linear(self.emb_size, self.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids,
                token_type_ids=None,
                pad_ids=None):
        # input_ids is a sequence of char n-gram ids for the embedding bag (!)
        # resolve the token_type default before token_type_ids is used below
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if self.old_bag:
            seq_length = token_type_ids.size(1)
        else:
            seq_length = input_ids.size(1)
        assert len(pad_ids.size()) == 2
        # assert pad_ids.size(1) == input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=token_type_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(token_type_ids)
        # words_embeddings = self.word_embeddings(input_ids)
        words_embeddings = self.embedding_bag(input_ids, pad_ids).view((-1,
                                                                        seq_length,
                                                                        self.emb_size))
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        if self.hidden_size != self.emb_size:
            embeddings = self.linear(embeddings)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
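
# --- Illustrative sketch (not part of the original gist) ----------------------
# BertEmbeddingBag reads a handful of attributes off its config object; the
# stand-in below shows a minimal config, with hypothetical paths and sizes.
from types import SimpleNamespace

example_config = SimpleNamespace(
    ngram_matrix_path='ngram_matrix.npy',   # (n_ngrams, emb_size) float matrix on disk
    ngram_dict_path='ngram_dict.pkl',       # only needed when old_bag=True
    old_bag=False,
    device=torch.device('cpu'),
    emb_size=300,                           # must equal ngram_matrix.shape[1]
    hidden_size=768,                        # transformer hidden size
    max_position_embeddings=512,
    type_vocab_size=2,
    hidden_dropout_prob=0.1,
)
# BertEmbeddingBag(example_config) would then load the matrix from ngram_matrix_path.
# ------------------------------------------------------------------------------
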
class FastTextEmbeddingBag(EmbeddingBag):
    """fastText-style embedding bag over pre-computed, padded tensors of n-gram ids.
    Builds the flat index / offsets pair expected by nn.EmbeddingBag without
    leaving the current device.
    """
    def __init__(self, input_matrix, device):
        # self.ngram_dict = ngram_dict
        self.device = device
        # TODO: settle on the final matrix size here
        super().__init__(input_matrix.shape[0], input_matrix.shape[1])
        self.weight.data.copy_(torch.FloatTensor(input_matrix))

    def forward(self, ids, pad_ids):
        # the previous numpy / dict based implementation is kept below as OldFastTextEmbeddingBag
        device = ids.device
        # pad_ids holds the index of the last valid n-gram per token; +1 turns it into a count
        pad_ids = pad_ids + 1
        # flatten (batch, seq_len, max_ngrams) -> (batch * seq_len, max_ngrams)
        ids = torch.reshape(ids, (-1, ids.shape[-1]))
        pad_ids = pad_ids.view(-1)
        word_subinds = torch.FloatTensor([]).to(device)
        word_offsets = torch.FloatTensor([0.0]).to(device)
        for i, word in enumerate(ids):
            # keep only the non-padded n-gram ids of this token
            word_subinds = torch.cat((word_subinds, word.float()[:pad_ids[i]]))
        # offsets are the running starts of each token's n-gram span
        word_offsets = torch.cat((word_offsets, torch.cumsum(pad_ids.float(), 0)))
        word_offsets = word_offsets[:-1]
        ind = word_subinds.long().to(device)
        offsets = word_offsets.long().to(device)
        return super().forward(ind, offsets)
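
# --- Illustrative sketch (not part of the original gist) ----------------------
# The forward pass above boils down to one nn.EmbeddingBag call with a flat
# index tensor plus an offsets tensor marking where each token's n-grams start.
# This tiny, hypothetical helper shows that contract in isolation.
def _embedding_bag_offsets_demo():
    bag = EmbeddingBag(num_embeddings=10, embedding_dim=4)  # default mode is 'mean'
    # two "words": the first made of n-grams [1, 2, 3], the second of [4, 5]
    ind = torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)
    offsets = torch.tensor([0, 3], dtype=torch.long)  # start of each word's span
    vectors = bag(ind, offsets)
    assert vectors.shape == (2, 4)  # one pooled vector per word
    return vectors
# ------------------------------------------------------------------------------
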
class OldFastTextEmbeddingBag(EmbeddingBag):
    """Original CPU implementation: builds the subword index / offset arrays in
    numpy from raw words via an n-gram dict, then runs one EmbeddingBag lookup.
    """
    def __init__(self, ngram_dict, input_matrix, device):
        self.ngram_dict = ngram_dict
        self.device = device
        # TODO: settle on the final matrix size here
        super().__init__(input_matrix.shape[0], input_matrix.shape[1])
        self.weight.data.copy_(torch.FloatTensor(input_matrix))

    def forward(self, words, pad_ids=None):
        word_subinds = np.empty([0], dtype=np.int64)
        word_offsets = [0]
        for word in words:
            # word_ngrams is an external helper producing fastText-style char n-grams
            subinds = [self.ngram_dict[gram] for gram in word_ngrams(word) if gram in self.ngram_dict]
            if subinds == []:
                subinds.append(self.ngram_dict['#UNK#'])
            word_subinds = np.concatenate((word_subinds, subinds))
            word_offsets.append(word_offsets[-1] + len(subinds))
        word_offsets = word_offsets[:-1]
        ind = torch.LongTensor(word_subinds).to(self.device)
        offsets = torch.LongTensor(word_offsets).to(self.device)
        return super().forward(ind, offsets)
class BertEmbeddingsWarmStart(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings,
    warm-starting the word embedding table from a pre-trained matrix.
    """
    def __init__(self, config):
        super(BertEmbeddingsWarmStart, self).__init__()
        # initialize with pre-trained embeddings
        # (nn.Embedding.from_pretrained freezes the weights by default;
        #  the matrix width must equal config.hidden_size for the sum below to work)
        ngram_matrix = torch.from_numpy(np.load(config.ngram_matrix_path)).float()
        self.word_embeddings = nn.Embedding.from_pretrained(ngram_matrix)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
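
# --- Illustrative usage sketch (not part of the original gist) ----------------
# The shapes below are assumptions inferred from the forward passes above:
# ids holds padded n-gram indices per token, and pad_ids holds, per token, the
# index of its last valid n-gram (FastTextEmbeddingBag.forward adds 1 to get the count).
if __name__ == '__main__':
    torch.manual_seed(0)
    np.random.seed(0)
    n_ngrams, emb_size = 50, 8
    toy_matrix = np.random.rand(n_ngrams, emb_size).astype(np.float32)
    bag = FastTextEmbeddingBag(toy_matrix, torch.device('cpu'))

    batch, seq_len, max_ngrams = 2, 3, 5
    ids = torch.randint(0, n_ngrams, (batch, seq_len, max_ngrams))
    # every token in this toy batch uses 3 n-grams, so its last valid index is 2
    pad_ids = torch.full((batch, seq_len), 2, dtype=torch.long)

    token_vectors = bag(ids, pad_ids).view(batch, seq_len, emb_size)
    print(token_vectors.shape)  # torch.Size([2, 3, 8])
# ------------------------------------------------------------------------------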