This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# find the best sequence of labels for each sample in the batch | |
best_sequences = [] | |
emission_lengths = mask.int().sum(dim=1) | |
for i in range(batch_size): | |
# recover the original sentence length for the i-th sample in the batch | |
sample_length = emission_lengths[i].item() | |
# recover the max tag for the last timestep |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def decode(self, emissions, mask=None): | |
"""Find the most probable sequence of labels given the emissions using | |
the Viterbi algorithm. | |
Args: | |
emissions (torch.Tensor): Sequence of emissions for each label. | |
Shape (batch_size, seq_len, nb_labels) if batch_first is True, | |
(seq_len, batch_size, nb_labels) otherwise. | |
mask (torch.FloatTensor, optional): Tensor representing valid positions. | |
If None, all positions are considered valid. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _compute_log_partition(self, emissions, mask): | |
"""Compute the partition function in log-space using the forward-algorithm. | |
Args: | |
emissions (torch.Tensor): (batch_size, seq_len, nb_labels) | |
mask (Torch.FloatTensor): (batch_size, seq_len) | |
Returns: | |
torch.Tensor: the partition scores for each batch. | |
Shape of (batch_size,) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _compute_scores(self, emissions, tags, mask): | |
"""Compute the scores for a given batch of emissions with their tags. | |
Args: | |
emissions (torch.Tensor): (batch_size, seq_len, nb_labels) | |
tags (Torch.LongTensor): (batch_size, seq_len) | |
mask (Torch.FloatTensor): (batch_size, seq_len) | |
Returns: | |
torch.Tensor: Scores for each batch. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, emissions, tags, mask=None):
    """Return the negative log-likelihood of ``tags`` given ``emissions``.

    Thin convenience wrapper: delegates to ``log_likelihood`` (see that
    method for argument shapes) and negates the result so the module can
    be used directly as a loss.
    """
    return -self.log_likelihood(emissions, tags, mask=mask)
def log_likelihood(self, emissions, tags, mask=None): | |
"""Compute the probability of a sequence of tags given a sequence of | |
emissions scores. | |
Args: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
class CRF(nn.Module): | |
""" | |
Linear-chain Conditional Random Field (CRF). | |
Args: | |
nb_labels (int): number of labels in your tagset, including special symbols. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Embedding-layer setup: sizes derived from the (externally built) vocabulary
# and the pre-trained Word2Vec model.
# vocabulary = OrderedDict()
input_length = None
# +1 because vocabulary values are 0-based word indices.
vocabulary_size = max(vocabulary.values()) + 1
# Pre-trained Word2Vec vector for every word in the vocabulary
# (assumes every vocabulary key is present in the Word2Vec model — verify upstream).
weights_w2v = list(map(Word2Vec.__getitem__, vocabulary.keys()))
# BUG FIX: original line read "embedding_size len(weights_w2v[0])" —
# the assignment operator was missing, which is a SyntaxError.
embedding_size = len(weights_w2v[0])
nb_classes = 5

# CNN hyperparams
nb_filter = 64
filter_length = 5
NewerOlder