Skip to content

Instantly share code, notes, and snippets.

View mtreviso's full-sized avatar

Marcos Treviso mtreviso

View GitHub Profile
# find the best sequence of labels for each sample in the batch
best_sequences = []
emission_lengths =
for i in range(batch_size):
# recover the original sentence length for the i-th sample in the batch
sample_length = emission_lengths[i].item()
# recover the max tag for the last timestep
def decode(self, emissions, mask=None):
"""Find the most probable sequence of labels given the emissions using
the Viterbi algorithm.
emissions (torch.Tensor): Sequence of emissions for each label.
Shape (batch_size, seq_len, nb_labels) if batch_first is True,
(seq_len, batch_size, nb_labels) otherwise.
mask (torch.FloatTensor, optional): Tensor representing valid positions.
If None, all positions are considered valid.
def _compute_log_partition(self, emissions, mask):
"""Compute the partition function in log-space using the forward-algorithm.
emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
mask (torch.FloatTensor): (batch_size, seq_len)
torch.Tensor: the partition scores for each batch.
Shape of (batch_size,)
def _compute_scores(self, emissions, tags, mask):
"""Compute the scores for a given batch of emissions with their tags.
emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
tags (torch.LongTensor): (batch_size, seq_len)
mask (torch.FloatTensor): (batch_size, seq_len)
torch.Tensor: Scores for each batch.
mtreviso /
Last active January 3, 2021 06:10
Code displayed on medium.
def forward(self, emissions, tags, mask=None):
    """Compute the negative log-likelihood. See `log_likelihood` method.

    This is a thin wrapper: it forwards its arguments unchanged to
    `self.log_likelihood` and negates the result.

    Args:
        emissions: Emission scores, passed through to `log_likelihood`.
        tags: Label sequences, passed through to `log_likelihood`.
        mask: Optional mask, passed through as the `mask` keyword.
            Defaults to None.

    Returns:
        The negation of `self.log_likelihood(emissions, tags, mask=mask)`.
    """
    return -self.log_likelihood(emissions, tags, mask=mask)
def log_likelihood(self, emissions, tags, mask=None):
"""Compute the probability of a sequence of tags given a sequence of
emissions scores.
mtreviso /
Last active March 29, 2019 00:39
File displayed on medium.
import torch
from torch import nn
class CRF(nn.Module):
Linear-chain Conditional Random Field (CRF).
nb_labels (int): number of labels in your tagset, including special symbols.
mtreviso /
Last active October 12, 2017 14:34
Variable sentence size CNN + RNN with Keras
# vocabulary = OrderedDict()
# NOTE(review): `vocabulary` and `Word2Vec` are not defined in this snippet;
# presumably `vocabulary` maps words to integer indices and `Word2Vec` is a
# loaded word-vector model -- confirm against the full script.
input_length = None
# One more than the largest value stored in `vocabulary`.
vocabulary_size = max(vocabulary.values()) + 1
# One looked-up entry per vocabulary key, in the dict's iteration order.
weights_w2v = list(map(Word2Vec.__getitem__, vocabulary.keys()))
# Length of the first looked-up entry.
embedding_size = len(weights_w2v[0])
nb_classes = 5
# CNN hyperparameters
nb_filter = 64
filter_length = 5