A top-k Viterbi decoder implemented in PyTorch, useful for probabilistic graphical models based on Conditional Random Fields (CRFs). Learn more here: https://nlp.stanford.edu/joberant/esslli_2016/kbest-ict.pdf
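For orientation: standard Viterbi keeps the single best-scoring prefix ending in each tag at every timestep; the top-k variant below instead keeps the k best prefixes per tag, replacing the max with a top-k over (prefix, previous tag) pairs. An informal sketch of the recurrence (not part of the gist), where score[t] is a (k, num_tags) table of prefix scores:

    score[t][:, j] = topk over (p, i) of ( score[t-1][p, i] + transition[i, j] ) + unary[t, j]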
import torch


# Credits to AllenNLP for the base implementation and base tests:
# https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py#L174
# Modified AllenNLP `viterbi_decode` to support `top_k` sequences efficiently.
def viterbi_decode(tag_sequence: torch.Tensor, transition_matrix: torch.Tensor, top_k: int = 5):
    """
    Perform Viterbi decoding in log space over a sequence given a transition matrix
    specifying pairwise (transition) potentials between tags and a matrix of shape
    (sequence_length, num_tags) specifying unary potentials for possible tags per
    timestep.

    Parameters
    ----------
    tag_sequence : torch.Tensor, required.
        A tensor of shape (sequence_length, num_tags) representing scores for
        a set of tags over a given sequence.
    transition_matrix : torch.Tensor, required.
        A tensor of shape (num_tags, num_tags) representing the binary potentials
        for transitioning between a given pair of tags.
    top_k : int, optional (default = 5)
        The number of paths to decode.

    Returns
    -------
    viterbi_paths : List[List[int]]
        The tag indices of the top_k maximum likelihood tag sequences,
        ordered from best to worst.
    viterbi_scores : torch.Tensor
        The scores of the decoded paths.
    """
    sequence_length, num_tags = list(tag_sequence.size())

    path_scores = []
    path_indices = []
    # At the beginning, the maximum number of permutations is 1; therefore, we unsqueeze(0)
    # to allow for 1 permutation.
    path_scores.append(tag_sequence[0, :].unsqueeze(0))
    # assert path_scores[0].size() == (n_permutations, num_tags)

    # Evaluate the scores for all possible paths.
    for timestep in range(1, sequence_length):
        # Add pairwise potentials to current scores.
        # assert path_scores[timestep - 1].size() == (n_permutations, num_tags)
        summed_potentials = path_scores[timestep - 1].unsqueeze(2) + transition_matrix
        summed_potentials = summed_potentials.view(-1, num_tags)

        # Best pairwise potential path scores from the previous timestep.
        max_k = min(summed_potentials.size()[0], top_k)
        scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)
        # assert scores.size() == (n_permutations, num_tags)
        # assert paths.size() == (n_permutations, num_tags)

        scores = tag_sequence[timestep, :] + scores
        # assert scores.size() == (n_permutations, num_tags)
        path_scores.append(scores)
        path_indices.append(paths.squeeze())

    # Construct the most likely sequences backwards.
    path_scores = path_scores[-1].view(-1)
    max_k = min(path_scores.size()[0], top_k)
    viterbi_scores, best_paths = torch.topk(path_scores, k=max_k, dim=0)
    viterbi_paths = []
    for i in range(max_k):
        # Cast to int so the returned paths hold plain Python ints rather than
        # 0-dim tensors.
        viterbi_path = [int(best_paths[i])]
        for backward_timestep in reversed(path_indices):
            viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))
        # Reverse the backward path.
        viterbi_path.reverse()
        # Viterbi paths index over (num_tags * n_permutations) nodes; therefore,
        # we take a modulo to recover the tag index.
        viterbi_path = [j % num_tags for j in viterbi_path]
        viterbi_paths.append(viterbi_path)
    return viterbi_paths, viterbi_scores
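# A minimal usage sketch (the variable names below are illustrative, not part
# of the original gist): decode the two best tag sequences for random unary
# and pairwise potentials.
demo_emissions = torch.rand(6, 4)    # (sequence_length, num_tags) unary potentials
demo_transitions = torch.rand(4, 4)  # (num_tags, num_tags) pairwise potentials
demo_paths, demo_scores = viterbi_decode(demo_emissions, demo_transitions, top_k=2)
# demo_paths[0] is the best tag sequence (a list of ints); demo_scores[0] is its score.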
###############################
# Testing
###############################
from tqdm import tqdm
import random
import numpy as np


def test_greedy():
    # Test that Viterbi decoding is equal to greedy decoding with no pairwise potentials.
    sequence_logits = torch.rand([5, 9])
    transition_matrix = torch.zeros([9, 9])
    indices, _ = viterbi_decode(sequence_logits, transition_matrix)
    _, argmax_indices = torch.max(sequence_logits, 1)
    assert indices[0] == argmax_indices.tolist()


test_greedy()
print('PASSED TEST GREEDY')
def test_inf():
    # Test that pairwise potentials affect the sequence correctly and that
    # viterbi_decode can handle -inf values.
    sequence_logits = torch.FloatTensor([[0, 0, 0, 3, 4], [0, 0, 0, 3, 4], [0, 0, 0, 3, 4],
                                         [0, 0, 0, 3, 4], [0, 0, 0, 3, 4], [0, 0, 0, 3, 4]])
    # The same tags shouldn't appear sequentially.
    transition_matrix = torch.zeros([5, 5])
    for i in range(5):
        transition_matrix[i, i] = float("-inf")
    indices, _ = viterbi_decode(sequence_logits, transition_matrix)
    assert indices[0] == [3, 4, 3, 4, 3, 4]


test_inf()
print('PASSED TEST INF')
def test_ties():
    # Test how unbalanced pairwise potentials break ties
    # between paths with equal unary potentials.
    sequence_logits = torch.FloatTensor([[0, 0, 0, 4, 4], [0, 0, 0, 4, 4], [0, 0, 0, 4, 4],
                                         [0, 0, 0, 4, 4], [0, 0, 0, 4, 4], [0, 0, 0, 4, 4]])
    # The 5th tag has a penalty for appearing sequentially or for transitioning
    # to the 4th tag. Note that transitions *into* the 5th tag are not penalized,
    # so [3, 3, 3, 3, 3, 4] ties with the all-3 path at a score of 24; since the
    # order of tied paths is implementation-dependent, we only assert that the
    # all-3 path is among the top two.
    transition_matrix = torch.zeros([5, 5])
    transition_matrix[4, 4] = -10
    transition_matrix[4, 3] = -10
    indices, _ = viterbi_decode(sequence_logits, transition_matrix)
    assert [3, 3, 3, 3, 3, 3] in indices[:2]


test_ties()
print('PASSED TEST TIES')
def test_transitions():
    sequence_logits = torch.FloatTensor([[1, 0, 0, 4], [1, 0, 6, 2], [0, 3, 0, 4]])
    # Best path would normally be [3, 2, 3] but we add a
    # potential from 2 -> 1, making [3, 2, 1] the best path.
    transition_matrix = torch.zeros([4, 4])
    transition_matrix[0, 0] = 1
    transition_matrix[2, 1] = 5
    indices, value = viterbi_decode(sequence_logits, transition_matrix)
    assert indices[0] == [3, 2, 1]
    assert value[0] == 18


test_transitions()
print('PASSED TEST TRANSITIONS')
# Use brute-force decoding as ground truth.
def brute_decode(tag_sequence: torch.Tensor, transition_matrix: torch.Tensor, top_k: int = 5):
    """
    Top-k decoder that uses brute-force search instead of the Viterbi dynamic
    programming algorithm.
    """
    # Create all possible sequences.
    sequence_length, num_tags = list(tag_sequence.size())
    sequences = [[]]
    for i in range(len(tag_sequence)):
        new_sequences = []
        for j in range(len(tag_sequence[i])):
            for sequence in sequences:
                new_sequences.append(sequence[:] + [j])
        sequences = new_sequences

    # Score each sequence as the sum of its unary and pairwise potentials.
    scored_sequences = []
    for sequence in sequences:
        emission_score = sum([tag_sequence[i, j] for i, j in enumerate(sequence)])
        transition_score = sum(
            [transition_matrix[sequence[i - 1], sequence[i]] for i in range(1, len(sequence))])
        score = emission_score + transition_score
        scored_sequences.append((score, sequence))

    # Get the top k scores / paths.
    top_k_sequences = sorted(scored_sequences, key=lambda r: r[0], reverse=True)[:top_k]
    scores, paths = zip(*top_k_sequences)
    return paths, scores
def test_brute():
    # Run 100 randomly generated parameter sets and compare the outputs.
    for i in tqdm(range(100)):
        num_tags = random.randint(1, 5)
        seq_len = random.randint(1, 5)
        k = random.randint(1, 5)
        sequence_logits = torch.rand([seq_len, num_tags])
        transition_matrix = torch.rand([num_tags, num_tags])
        viterbi_paths_v1, viterbi_scores_v1 = viterbi_decode(
            sequence_logits, transition_matrix, top_k=k)
        viterbi_path_brute, viterbi_score_brute = brute_decode(
            sequence_logits, transition_matrix, top_k=k)
        np.testing.assert_almost_equal(
            list(viterbi_score_brute), viterbi_scores_v1.tolist(), decimal=3)


test_brute()
print('PASSED TEST BRUTE')
Is there a version for batch processing instead of only processing one sequence at a time?
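Not in the gist itself, but a simple workaround is to loop over the batch, as in the sketch below; a fully vectorized version would additionally need length masking inside the top-k recurrence. The helper name and the padded-batch layout here are assumptions, not part of the original code.

def batch_viterbi_decode(batch, lengths, transition_matrix, top_k=5):
    # batch: (batch_size, max_sequence_length, num_tags) padded unary potentials
    # lengths: the true (unpadded) length of each sequence in the batch
    results = []
    for tag_sequence, length in zip(batch, lengths):
        # Slice off the padding, then reuse the single-sequence decoder.
        paths, scores = viterbi_decode(tag_sequence[:length], transition_matrix, top_k=top_k)
        results.append((paths, scores))
    return results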
I ran this code and the Ties test fails... is it supposed to pass?
print(indices) before the assert statement in test_ties() returns
[[3, 3, 3, 3, 3, tensor(4)], [3, 3, 3, 3, 3, tensor(3)], [3, 3, 3, 3, 4, tensor(2)], [3, 3, 3, 3, 4, tensor(1)], [3, 3, 3, 3, 4, tensor(0)]]
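The failure appears to trace to two separate issues rather than a decoding error. The stray tensor(...) elements show up when `best_paths[i]` is appended without an `int(...)` cast. And the test's transition matrix only penalizes transitions *out of* tag 4, so [3, 3, 3, 3, 3, 4] genuinely ties with the all-3 path at a score of 24, and the order in which torch.topk returns tied paths is implementation-dependent. A quick check of the tie using the test's own potentials:

import torch
logits = torch.FloatTensor([[0, 0, 0, 4, 4]] * 6)
transitions = torch.zeros(5, 5)
transitions[4, 4] = -10
transitions[4, 3] = -10

def score(path):
    # Sum of unary potentials plus pairwise potentials along the path.
    unary = sum(float(logits[i, t]) for i, t in enumerate(path))
    pairwise = sum(float(transitions[path[i - 1], path[i]]) for i in range(1, len(path)))
    return unary + pairwise

print(score([3] * 6), score([3] * 5 + [4]))  # both print 24.0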
Nice! Thanks for sharing. What license does this code have? Apache 2.0 like AllenNLP?