Skip to content

Instantly share code, notes, and snippets.

Last active June 17, 2022 22:20
Show Gist options
  • Save PetrochukM/afaa3613a99a8e7213d2efdd02ae4762 to your computer and use it in GitHub Desktop.
Save PetrochukM/afaa3613a99a8e7213d2efdd02ae4762 to your computer and use it in GitHub Desktop.
Implemented a Top K Viterbi Decoder algorithm in PyTorch. Useful for Conditional Random Fields (CRFs)-based probabilistic graphical modelling. Learn more here:
import torch
# Credits to AllenNLP for the base implementation and base tests
# (`viterbi_decode` in `allennlp/nn/util.py`).
# Modified AllenNLP `viterbi_decode` to support `top_k` sequences efficiently.
def viterbi_decode(tag_sequence: torch.Tensor, transition_matrix: torch.Tensor, top_k: int=5):
    """
    Perform top-k Viterbi decoding in log space over a sequence given a transition matrix
    specifying pairwise (transition) potentials between tags and a matrix of shape
    (sequence_length, num_tags) specifying unary potentials for possible tags per
    timestep.

    Parameters
    ----------
    tag_sequence : torch.Tensor, required.
        A tensor of shape (sequence_length, num_tags) representing scores for
        a set of tags over a given sequence.
    transition_matrix : torch.Tensor, required.
        A tensor of shape (num_tags, num_tags) representing the binary potentials
        for transitioning between a given pair of tags; entry [i, j] is added when
        tag i is followed by tag j.
    top_k : int, optional (default = 5).
        Integer defining the top number of paths to decode.

    Returns
    -------
    viterbi_paths : List[List[int]]
        The tag indices of the best-scoring tag sequences, best first. Fewer than
        ``top_k`` paths are returned when fewer candidates exist.
    viterbi_scores : torch.Tensor
        A 1-D tensor holding the score of each returned path, in the same order.
    """
    sequence_length, num_tags = list(tag_sequence.size())

    # Backpointers for timesteps 1..sequence_length-1, one tensor per timestep.
    path_indices = []

    # At the beginning, the maximum number of permutations is 1; therefore, we unsqueeze(0)
    # to allow for 1 permutation.
    scores = tag_sequence[0, :].unsqueeze(0)
    # assert scores.size() == (n_permutations, num_tags)

    # Evaluate the scores for all possible paths.
    for timestep in range(1, sequence_length):
        # Add pairwise potentials to current scores.
        # (n_permutations, num_tags, 1) + (num_tags, num_tags) broadcasts to
        # (n_permutations, previous_tag, current_tag).
        summed_potentials = scores.unsqueeze(2) + transition_matrix
        # Flatten (permutation, previous_tag) into a single "node" axis so that row
        # index == permutation * num_tags + previous_tag.
        summed_potentials = summed_potentials.view(-1, num_tags)

        # Best pairwise potential path scores from the previous timestep.
        max_k = min(summed_potentials.size()[0], top_k)
        scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)
        # assert scores.size() == (n_permutations, num_tags)
        # assert paths.size() == (n_permutations, num_tags)

        scores = tag_sequence[timestep, :] + scores
        # Record the backpointers; without this the backward pass has nothing to follow.
        path_indices.append(paths)

    # Construct the most likely sequences backwards.
    final_scores = scores.reshape(-1)
    max_k = min(final_scores.size()[0], top_k)
    viterbi_scores, best_paths = torch.topk(final_scores, k=max_k, dim=0)

    # Convert once to plain Python ints so the returned paths never contain tensors.
    backpointers = [paths.reshape(-1).tolist() for paths in path_indices]

    viterbi_paths = []
    for last_node in best_paths.tolist():
        viterbi_path = [last_node]
        for backward_timestep in reversed(backpointers):
            viterbi_path.append(backward_timestep[viterbi_path[-1]])
        # Reverse the backward path.
        viterbi_path.reverse()
        # Viterbi paths uses (num_tags * n_permutations) nodes; therefore, we need to modulo.
        viterbi_paths.append([node % num_tags for node in viterbi_path])
    return viterbi_paths, viterbi_scores
# Testing
from torch.autograd import Variable
from tqdm import tqdm
import random
import numpy as np
def test_greedy():
    """Viterbi decoding must equal greedy decoding when there are no pairwise potentials."""
    # Plain tensors are used directly; the `Variable` wrapper is no longer needed.
    sequence_logits = torch.rand([5, 9])
    transition_matrix = torch.zeros([9, 9])
    indices, _ = viterbi_decode(sequence_logits, transition_matrix)
    # With an all-zero transition matrix the best path is the per-timestep argmax.
    _, argmax_indices = torch.max(sequence_logits, 1)
    assert indices[0] == argmax_indices.tolist()
def test_inf():
    """Pairwise potentials must affect the decoded sequence, and -inf values must be handled."""
    # The last timestep favours tag 4 slightly more than the others do. With identical
    # rows, [3, 4, 3, 4, 3, 4] and [4, 3, 4, 3, 4, 3] would both score 21 and the result
    # would depend on how `torch.topk` happens to break the tie.
    sequence_logits = torch.FloatTensor([[0, 0, 0, 3, 4], [0, 0, 0, 3, 4], [0, 0, 0, 3, 4],
                                         [0, 0, 0, 3, 4], [0, 0, 0, 3, 4], [0, 0, 0, 3, 5]])
    # The same tags shouldn't appear sequentially.
    transition_matrix = torch.zeros([5, 5])
    for i in range(5):
        transition_matrix[i, i] = float("-inf")
    indices, _ = viterbi_decode(sequence_logits, transition_matrix)
    assert indices[0] == [3, 4, 3, 4, 3, 4]
def test_ties():
    """Unbalanced pairwise potentials must break ties between paths with equal unary potentials."""
    sequence_logits = torch.FloatTensor([[0, 0, 0, 4, 4], [0, 0, 0, 4, 4], [0, 0, 0, 4, 4],
                                         [0, 0, 0, 4, 4], [0, 0, 0, 4, 4], [0, 0, 0, 4, 4]])
    # The 5th tag has a penalty for appearing sequentially and for transitioning
    # to or from the 4th tag, making the best path uniquely to take the 4th tag only.
    transition_matrix = torch.zeros([5, 5])
    transition_matrix[4, 4] = -10
    transition_matrix[4, 3] = -10
    # Without this penalty [3, 3, 3, 3, 3, 4] scores 24 as well, so the expected path
    # is not unique and the assertion depends on tie-breaking order.
    transition_matrix[3, 4] = -10
    indices, _ = viterbi_decode(sequence_logits, transition_matrix)
    assert indices[0] == [3, 3, 3, 3, 3, 3]
def test_transitions():
    """A single strong transition potential must redirect the best path and its score."""
    unary = torch.FloatTensor([[1, 0, 0, 4], [1, 0, 6, 2], [0, 3, 0, 4]])
    # On unary scores alone the winner is [3, 2, 3]; rewarding the 2 -> 1
    # transition makes [3, 2, 1] the best path instead.
    pairwise = torch.zeros([4, 4])
    pairwise[0, 0] = 1
    pairwise[2, 1] = 5
    decoded_paths, decoded_scores = viterbi_decode(unary, pairwise)
    best_path = decoded_paths[0]
    best_score = decoded_scores[0]
    assert best_path == [3, 2, 1]
    assert best_score == 18
# Use the brute decoding as truth
def brute_decode(tag_sequence: torch.Tensor, transition_matrix: torch.Tensor, top_k: int=5):
    """
    Top-k decoder that uses brute-force search instead of the Viterbi dynamic programming
    algorithm. Every possible tag sequence is enumerated and scored, so the cost grows as
    num_tags ** sequence_length; intended only as a reference for small inputs.

    Parameters
    ----------
    tag_sequence : torch.Tensor, required.
        A 2-D tensor of shape (sequence_length, num_tags) with the unary score of each
        tag at each timestep.
    transition_matrix : torch.Tensor, required.
        A tensor indexed as [previous_tag, tag] with the pairwise transition scores.
    top_k : int, optional (default = 5).
        Maximum number of sequences to return.

    Returns
    -------
    paths : tuple of List[int]
        The best-scoring tag sequences, best first.
    scores : tuple
        The score of each returned path, in the same order. Both tuples are empty
        when there is no candidate to return.
    """
    sequence_length, num_tags = list(tag_sequence.size())

    # Create all possible sequences by extending every prefix with every tag.
    sequences = [[]]
    for _ in range(sequence_length):
        sequences = [prefix + [tag] for tag in range(num_tags) for prefix in sequences]

    # Score each sequence: unary scores plus the transitions between neighbouring tags.
    scored_sequences = []
    for sequence in sequences:
        emission_score = sum(tag_sequence[i, tag] for i, tag in enumerate(sequence))
        transition_score = sum(
            transition_matrix[previous_tag, tag]
            for previous_tag, tag in zip(sequence, sequence[1:]))
        scored_sequences.append((emission_score + transition_score, sequence))

    # Get the top k scores / paths.
    top_k_sequences = sorted(scored_sequences, key=lambda r: r[0], reverse=True)[:top_k]
    if not top_k_sequences:
        # `zip(*[])` yields nothing, so the unpacking below would raise.
        return (), ()
    scores, paths = zip(*top_k_sequences)
    return paths, scores
def test_brute():
    """Compare top-k Viterbi scores against brute-force search on 100 random problems."""
    # Run 100 randomly generated parameters and compare the outputs.
    for _ in tqdm(range(100)):
        num_tags = random.randint(1, 5)
        seq_len = random.randint(1, 5)
        k = random.randint(1, 5)
        sequence_logits = torch.rand([seq_len, num_tags])
        transition_matrix = torch.rand([num_tags, num_tags])
        _, viterbi_scores_v1 = viterbi_decode(
            sequence_logits, transition_matrix, top_k=k)
        _, viterbi_score_brute = brute_decode(
            sequence_logits, transition_matrix, top_k=k)
        # Only scores are compared; they are converted to plain floats first so
        # numpy compares numbers rather than 0-dim tensors.
        np.testing.assert_almost_equal(
            [float(score) for score in viterbi_score_brute],
            viterbi_scores_v1.tolist(),
            decimal=3)
Copy link

Nice! Thanks for sharing. What license does this code have? Apache 2.0 like AllenNLP?

Copy link

jind11 commented Aug 27, 2018

Is there a version for batch processing instead of only processing one sequence at one time?

Copy link

kroegern1 commented Jun 17, 2022

I run this code and the Ties test fails... is this supposed to pass?

print(indices) before the assert statement in test_ties() returns
[[3, 3, 3, 3, 3, tensor(4)], [3, 3, 3, 3, 3, tensor(3)], [3, 3, 3, 3, 4, tensor(2)], [3, 3, 3, 3, 4, tensor(1)], [3, 3, 3, 3, 4, tensor(0)]]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment