Skip to content

Instantly share code, notes, and snippets.

Created October 16, 2019 09:09
Show Gist options
  • Save yzhangcs/62efc71eef8f8d7eb541158008f52894 to your computer and use it in GitHub Desktop.
Save yzhangcs/62efc71eef8f8d7eb541158008f52894 to your computer and use it in GitHub Desktop.
Implementations of the inside, outside, and Eisner algorithms
# -*- coding: utf-8 -*-
import torch
from torch.nn.utils.rnn import pad_sequence
@torch.enable_grad()
def crf(scores, mask, target=None, partial=False):
    r'''Loss of a first-order projective dependency CRF and, outside
    training, its arc marginals.

    Parameters:
        scores (Tensor): ``[batch_size, seq_len, seq_len]`` arc scores; the
            last dim indexes the head, as the ``gather`` against ``target``
            below shows.
        mask (BoolTensor): ``[batch_size, seq_len]`` mask of the tokens to
            score. ``mask.sum(1)`` is used as the sentence length, so
            position 0 (the root) is presumably ``False`` -- TODO confirm.
        target (LongTensor, optional): ``[batch_size, seq_len]`` gold head of
            every token. When ``None``, all-zero pseudo heads are used so that
            a loss can still be returned at prediction time.
        partial (bool): if ``True``, the gold score is the log-sum over all
            trees consistent with ``target`` (a second, constrained
            ``inside`` pass) instead of the score of a single tree.

    Returns:
        ``(loss, None)`` when ``scores.requires_grad`` (training); otherwise
        ``(loss, marginal_probs)``, where ``marginal_probs`` has the shape of
        ``scores`` and is the gradient of log Z w.r.t. ``scores``, i.e. the
        arc marginals. ``loss`` is ``(logZ - score) / total`` with ``total``
        the number of masked tokens.
    '''
    lens = mask.sum(1)
    total = lens.sum()
    batch_size, seq_len, _ = scores.shape
    training = scores.requires_grad
    if not training:
        # marginals are obtained through autograd, so a leaf that tracks
        # gradients is needed; detach first so the caller's tensor is not
        # flagged with requires_grad as a side effect
        scores = scores.detach().requires_grad_()
    _, s_ic = inside(scores, mask)
    logZ = s_ic[0].gather(0, lens.unsqueeze(0)).sum()
    # use pseudo answers if target is None in the prediction phase
    if target is None:
        target = lens.new_zeros(batch_size, seq_len)
    if partial:
        # NOTE(review): the source this was recovered from contained a
        # truncated statement `total = total - ...` here; the subtracted term
        # was lost, so `total` is left as the full token count. Restore it
        # from upstream if the normalisation should exclude some tokens.
        # The second inside pass sums over trees consistent with `target`.
        _, s_ic = inside(scores, mask, target)
        score = s_ic[0].gather(0, lens.unsqueeze(0)).sum()
    else:
        score = scores.gather(-1, target.unsqueeze(-1)).squeeze(-1)[mask].sum()
    loss = (logZ - score) / total
    if training:
        return loss, None
    # marginal probs are used for decoding; d logZ / d scores gives them
    # directly via autograd instead of a full inside-outside pass
    marginal_probs, = torch.autograd.grad(logZ, scores)
    return loss.detach(), marginal_probs
def inside(scores, mask, candidates=None):
    r'''Inside algorithm (log space) for first-order projective dependency
    trees.

    Parameters:
        scores (Tensor): ``[batch_size, seq_len, seq_len]`` arc scores. After
            the ``permute(2, 1, 0)`` below, ``scores[i, j]`` is the S(i, j)
            added to the incomplete span I(i, j).
        mask (BoolTensor): ``[batch_size, seq_len]``; ``mask.sum(1)`` is the
            sentence length, and column 0 is forced on here ("include the
            first token").
        candidates (LongTensor, optional): ``[batch_size, seq_len]`` head of
            every token. Only trees whose arcs agree with it are summed over;
            a negative entry leaves that token's head unconstrained. Not
            modified.

    Returns:
        ``(s_ii, s_ic)``: ``[seq_len, seq_len, batch_size]`` log inside scores
        of incomplete and complete spans. ``s_ic[0, lens[b], b]`` is the span
        covering the whole sentence (read as log Z by ``crf``).
    '''
    # the end position of each sentence in a batch
    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    # [seq_len, seq_len, batch_size]
    scores = scores.permute(2, 1, 0)
    # include the first token of each sentence
    mask = mask.index_fill(1, lens.new_tensor(0), 1)
    # [seq_len, seq_len, batch_size]
    mask = (mask.unsqueeze(1) & mask.unsqueeze(-1)).permute(2, 1, 0)
    # allocate fresh contiguous charts: stripe() derives its strides from a
    # contiguous layout, which full_like() on the permuted scores does not
    # guarantee (it may preserve the permuted strides)
    s_ii = scores.new_full((seq_len, seq_len, batch_size), float('-inf'))
    s_ic = scores.new_full((seq_len, seq_len, batch_size), float('-inf'))
    # a width-0 complete span contains no arcs, so its log score is 0
    s_ic.diagonal().fill_(0)
    # set the scores of arcs excluded by candidates to -inf
    if candidates is not None:
        # mark the root's own head as unconstrained; index_fill (not
        # index_fill_) so the caller's tensor is left untouched
        heads = candidates.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
        # [batch_size, seq_len, seq_len]: the given head, or every head for
        # negative (unconstrained) entries
        candidates = heads.eq(lens.new_tensor(range(seq_len))) | heads.lt(0)
        candidates = candidates.permute(2, 1, 0) & mask
        scores = scores.masked_fill(~candidates, float('-inf'))
    for w in range(1, seq_len):
        # n denotes the number of spans to iterate,
        # from span (0, w) to span (n, n+w) given width w
        n = seq_len - w
        # diag_mask is used for ignoring the excess of each sentence
        # [batch_size, n]
        cand_mask = diag_mask = mask.diagonal(w)
        # ilr = C(i, r) + C(j, r+1)
        # [n, w, batch_size]
        ilr = stripe(s_ic, n, w) + stripe(s_ic, n, w, (w, 1))
        if candidates is not None:
            # only update spans with at least one finite split; presumably
            # to avoid NaN gradients from a logsumexp over all -inf
            cand_mask = torch.isfinite(ilr).any(1).t() & diag_mask
        ilr = ilr.permute(2, 0, 1)[cand_mask].logsumexp(-1)
        # I(j, i) = logsumexp(C(i, r) + C(j, r+1)) + S(j, i), i <= r < j
        il = ilr + scores.diagonal(-w)[cand_mask]
        # fill the w-th diagonal of the lower triangular part of s_ii
        # with I(j, i) of n spans
        s_ii.diagonal(-w)[cand_mask] = il
        # I(i, j) = logsumexp(C(i, r) + C(j, r+1)) + S(i, j), i <= r < j
        ir = ilr + scores.diagonal(w)[cand_mask]
        # fill the w-th diagonal of the upper triangular part of s_ii
        # with I(i, j) of n spans
        s_ii.diagonal(w)[cand_mask] = ir

        # C(j, i) = logsumexp(C(r, i) + I(j, r)), i <= r < j
        cl = stripe(s_ic, n, w, dim=0) + stripe(s_ii, n, w, (w, 0))
        if candidates is not None:
            cand_mask = torch.isfinite(cl).any(1).t() & diag_mask
        cl = cl.permute(2, 0, 1)[cand_mask].logsumexp(-1)
        s_ic.diagonal(-w)[cand_mask] = cl
        # C(i, j) = logsumexp(I(i, r) + C(r, j)), i < r <= j
        cr = stripe(s_ii, n, w, (0, 1)) + stripe(s_ic, n, w, (1, w), 0)
        if candidates is not None:
            cand_mask = torch.isfinite(cr).any(1).t() & diag_mask
        cr = cr.permute(2, 0, 1)[cand_mask].logsumexp(-1)
        s_ic.diagonal(w)[cand_mask] = cr
        # disable multi words to modify the root node: the root's complete
        # span C(0, w) is only kept when w is the sentence length
        s_ic[0, w][lens.ne(w)] = float('-inf')
    return s_ii, s_ic
def outside(scores, mask):
    r'''Outside pass over the charts built by ``inside``.

    Runs ``inside(scores, mask)`` first, then fills the outside charts of
    incomplete (``s_oi``) and complete (``s_oc``) spans for widths
    ``seq_len - 2`` down to 0. ``crf`` does not call this function; it
    obtains marginals through autograd instead.

    NOTE(review): this routine looks unfinished and has not been checked
    against a reference implementation. For example, the initialisation
    ``s_ic[lens, lens]`` indexes the two span dims with ``[batch_size]``
    tensors and so yields a ``[batch_size, batch_size]`` tensor. Confirm the
    intent before relying on the result.

    Parameters:
        scores (Tensor): ``[batch_size, seq_len, seq_len]`` arc scores.
        mask (BoolTensor): ``[batch_size, seq_len]``; column 0 is forced on.

    Returns:
        ``(s_ii, s_ic, s_oi, s_oc)``, each ``[seq_len, seq_len, batch_size]``.
    '''
    s_ii, s_ic = inside(scores, mask)
    # the end position of each sentence in a batch
    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    # [seq_len, seq_len, batch_size]
    scores = scores.permute(2, 1, 0)
    # include the first token of each sentence
    mask = mask.index_fill(1, lens.new_tensor(0), 1)
    # [seq_len, seq_len, batch_size]
    mask = (mask.unsqueeze(1) & mask.unsqueeze(-1)).permute(2, 1, 0)
    s_oi = torch.full_like(scores, float('-inf'))
    s_oc = torch.full_like(scores, float('-inf'))
    s_oi[0].scatter_(0, lens.unsqueeze(0), s_ic[lens, lens])
    s_oc[0].scatter_(0, lens.unsqueeze(0), 0)
    for w in reversed(range(seq_len - 1)):
        n = seq_len - w
        # [batch_size, n]
        # NOTE(review): in the recovered source this statement ended in a
        # dangling `&` whose right operand was lost; it is completed here
        # like the diag mask in inside(). Restore any extra condition from
        # upstream.
        diag_mask = mask.diagonal(w)
        # [n, n, batch_size]
        # II(r, j), OC(r, i), j < r < N
        iil = triu(s_ii, n, n - 1, (w+1, w), 0)
        ocl = triu(s_oc, n, n - 1, (w+1, 0), 0)
        # IC(r, i - 1), OI(j, r), OI(r, j), 0 <= r < i
        icr = tril(s_ic, n, n - 1, dim=0)
        oil = tril(s_oi, n, n - 1, (w+1, 0))
        oir = tril(s_oi, n, n - 1, (0, w+1), 0)
        # S(j, r), S(r, j), 0 <= r < i
        sl = tril(scores, n, n - 1, (w+1, 0))
        sr = tril(scores, n, n - 1, (0, w+1), 0)
        # cr = logsumexp(II(r, j) + OC(r, i)), j < r < N
        cr = (iil + ocl).permute(2, 0, 1).logsumexp(-1)
        # cll = logsumexp(IC(r, i - 1) + OI(j, r)) + S(j, r), 0 <= r < i
        cll = (icr + oil + sl).permute(2, 0, 1).logsumexp(-1)
        # clr = logsumexp(IC(r, i - 1) + OI(r, j)) + S(r, j), 0 <= r < i
        clr = (icr + oir + sr).permute(2, 0, 1).logsumexp(-1)
        # [batch_size, n]
        ocl = torch.stack((cr, cll, clr), -1).logsumexp(-1)
        # OC[j, i] = logsumexp(cr + cll + clr)
        s_oc.diagonal(-w)[diag_mask] = ocl[diag_mask]
        # [n, n, batch_size]
        # II(r, i), OC(r, j), 0 <= r < i
        iir = tril(s_ii, n, n - 1, (0, 1), 0)
        ocr = tril(s_oc, n, n - 1, (0, w+1), 0)
        # IC(r, j + 1), OI(r, i), OI(i, r), j < r < N
        icl = triu(s_ic, n, n - 1, (w+1, w+1), 0)
        oil = triu(s_oi, n, n - 1, (w+1, 0), 0)
        oir = triu(s_oi, n, n - 1, (0, w+1))
        # S(r, i), S(i, r), j < r < N
        sl = triu(scores, n, n - 1, (w+1, 0), 0)
        sr = triu(scores, n, n - 1, (0, w+1))
        # cl = logsumexp(II(r, i) + OC(r, j)), 0 <= r < i
        cl = (iir + ocr).permute(2, 0, 1).logsumexp(-1)
        # crl = logsumexp(IC(r, j + 1) + OI(r, i)) + S(r, i), j < r < N
        crl = (icl + oil + sl).permute(2, 0, 1).logsumexp(-1)
        # crr = logsumexp(IC(r, j + 1) + OI(i, r)) + S(i, r), j < r < N
        crr = (icl + oir + sr).permute(2, 0, 1).logsumexp(-1)
        # [batch_size, n]
        ocr = torch.stack((cl, crl, crr), -1).logsumexp(-1)
        # for w > 0 the span starting at the root (index 0) gets -inf
        if w != 0:
            ocr[:, 0] = float('-inf')
        # OC[i, j] = logsumexp(cl + crr + crl)
        s_oc.diagonal(w)[diag_mask] = ocr[diag_mask]
        # [n, n, batch_size]
        # IC(i, r), OC(j, r), 0 <= r <= i
        icl = tril(s_ic, n, n, (0, 0))
        ocl = tril(s_oc, n, n, (w, 0))
        # IC(j, r), OC(i, r), j <= r < N
        icr = triu(s_ic, n, n, (w, w))
        ocr = triu(s_oc, n, n, (0, w))
        # OI[j, i] = logsumexp(IC(i, r) + OC(j, r)), 0 <= r <= i
        oil = (icl + ocl).permute(2, 0, 1).logsumexp(-1)
        s_oi.diagonal(-w)[diag_mask] = oil[diag_mask]
        # OI[i, j] = logsumexp(IC(j, r) + OC(i, r)), j <= r < N
        oir = (icr + ocr).permute(2, 0, 1).logsumexp(-1)
        s_oi.diagonal(w)[diag_mask] = oir[diag_mask]
    return s_ii, s_ic, s_oi, s_oc
def eisner(scores, mask):
    r'''Eisner's algorithm: the highest-scoring projective dependency tree of
    every sentence in the batch.

    Parameters:
        scores (Tensor): ``[batch_size, seq_len, seq_len]`` arc scores. After
            the ``permute(2, 1, 0)`` below, ``scores[i, j]`` is the S(i, j)
            added to the incomplete span I(i, j).
        mask (BoolTensor): ``[batch_size, seq_len]``; only ``mask.sum(1)``
            (the number of words per sentence) is used.

    Returns:
        LongTensor ``[batch_size, max(lens) + 1]`` on ``mask.device``:
        ``out[b, d]`` is the predicted head of word ``d``. Rows are
        right-padded with 0 by ``pad_sequence``. Position 0 (the root) keeps
        its initial value 1.
    '''
    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    # [seq_len, seq_len, batch_size]
    scores = scores.permute(2, 1, 0)
    # fresh contiguous charts: stripe() derives its strides from a contiguous
    # layout, which full_like() on the permuted scores does not guarantee
    s_i = scores.new_full((seq_len, seq_len, batch_size), float('-inf'))
    s_c = scores.new_full((seq_len, seq_len, batch_size), float('-inf'))
    # back-pointers: the split point r chosen for each span
    p_i = scores.new_zeros(seq_len, seq_len, batch_size).long()
    p_c = scores.new_zeros(seq_len, seq_len, batch_size).long()
    # a width-0 complete span contains no arcs, so it scores 0
    s_c.diagonal().fill_(0)

    for w in range(1, seq_len):
        n = seq_len - w
        # left end of each of the n spans of width w, to turn argmax offsets
        # into absolute split points
        starts = p_i.new_tensor(range(n)).unsqueeze(0)
        # ilr = C(i, r) + C(j, r+1)
        ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        # [batch_size, n, w]
        ilr = ilr.permute(2, 0, 1)
        # I(j, i) = max(C(i, r) + C(j, r+1) + S(j, i)), i <= r < j
        il_span, il_path = (ilr + scores.diagonal(-w).unsqueeze(-1)).max(-1)
        s_i.diagonal(-w).copy_(il_span)
        p_i.diagonal(-w).copy_(il_path + starts)
        # I(i, j) = max(C(i, r) + C(j, r+1) + S(i, j)), i <= r < j
        ir_span, ir_path = (ilr + scores.diagonal(w).unsqueeze(-1)).max(-1)
        s_i.diagonal(w).copy_(ir_span)
        p_i.diagonal(w).copy_(ir_path + starts)

        # C(j, i) = max(C(r, i) + I(j, r)), i <= r < j
        cl = stripe(s_c, n, w, dim=0) + stripe(s_i, n, w, (w, 0))
        cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
        s_c.diagonal(-w).copy_(cl_span)
        p_c.diagonal(-w).copy_(cl_path + starts)
        # C(i, j) = max(I(i, r) + C(r, j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
        s_c.diagonal(w).copy_(cr_span)
        # only one word may attach to the root: keep the root's complete
        # span C(0, w) only when w is the sentence length
        s_c[0, w][lens.ne(w)] = float('-inf')
        p_c.diagonal(w).copy_(cr_path + starts + 1)

    predicts = []
    # [batch_size, seq_len, seq_len], on the CPU for the Python backtracking
    p_c = p_c.permute(2, 0, 1).cpu()
    p_i = p_i.permute(2, 0, 1).cpu()
    for i, length in enumerate(lens.tolist()):
        heads = p_c.new_ones(length + 1, dtype=torch.long)
        backtrack(p_i[i], p_c[i], heads, 0, length, True)
        predicts.append(heads.to(mask.device))
    return pad_sequence(predicts, True)
def backtrack(p_i, p_c, heads, i, j, complete):
    r'''Recover head assignments from Eisner back-pointers, writing them into
    ``heads`` in place.

    Parameters:
        p_i (Tensor): ``[seq_len, seq_len]`` split points of incomplete spans.
        p_c (Tensor): ``[seq_len, seq_len]`` split points of complete spans.
        heads (LongTensor): 1-D; ``heads[j]`` is set to ``i`` for every
            incomplete span ``(i, j)`` on the best derivation.
        i (int): the head end of the span (may be greater than ``j`` for
            left-pointing spans).
        j (int): the other end of the span.
        complete (bool): whether ``(i, j)`` is a complete span.
    '''
    # an empty span carries no arcs
    if i == j:
        return
    if complete:
        # C(i, j) = I(i, r) + C(r, j)
        r = int(p_c[i, j])
        backtrack(p_i, p_c, heads, i, r, False)
        backtrack(p_i, p_c, heads, r, j, True)
    else:
        # I(i, j) means arc i -> j, split into C(lo, r) + C(hi, r + 1)
        r, heads[j] = int(p_i[i, j]), i
        i, j = sorted((i, j))
        backtrack(p_i, p_c, heads, i, r, True)
        backtrack(p_i, p_c, heads, j, r + 1, True)
def stripe(x, n, w, offset=(0, 0), dim=1):
    r'''Returns a diagonal stripe of the tensor.

    Parameters:
        x (Tensor): the input tensor with 2 or more dims; the first two dims
            are treated as a chart.
        n (int): the length of the stripe.
        w (int): the width of the stripe.
        offset (tuple): the offset of the first two dims.
        dim (int): 1 to take each row of the stripe along dim 1 of ``x``
            (horizontal segments); 0 to take it along dim 0 (vertical
            segments).

    Returns:
        Tensor of shape ``[n, w, *x.shape[2:]]`` whose element ``[k, t]`` is
        ``x[offset[0] + k, offset[1] + k + t]`` when ``dim == 1`` and
        ``x[offset[0] + k + t, offset[1] + k]`` when ``dim == 0``. It is a
        view of ``x`` when ``x`` is contiguous and of a contiguous copy
        otherwise, so treat it as read-only.

    Example::
        >>> x = torch.arange(25).view(5, 5)
        >>> x
        tensor([[ 0,  1,  2,  3,  4],
                [ 5,  6,  7,  8,  9],
                [10, 11, 12, 13, 14],
                [15, 16, 17, 18, 19],
                [20, 21, 22, 23, 24]])
        >>> stripe(x, 2, 3, (1, 1))
        tensor([[ 6,  7,  8],
                [12, 13, 14]])
        >>> stripe(x, 2, 3, dim=0)
        tensor([[ 0,  5, 10],
                [ 6, 11, 16]])
    '''
    # the strides below assume a row-major contiguous layout
    x, seq_len = x.contiguous(), x.size(1)
    stride, numel = list(x.stride()), x[0, 0].numel()
    # one step along the stripe moves one row down and one column right
    stride[0] = (seq_len + 1) * numel
    stride[1] = (1 if dim == 1 else seq_len) * numel
    # as_strided's storage_offset is absolute, so add the view's own offset
    start = x.storage_offset() + (offset[0] * seq_len + offset[1]) * numel
    return x.as_strided(size=(n, w, *x.shape[2:]),
                        stride=stride,
                        storage_offset=start)
def tril(x, n, w, offset=(0, 0), dim=1):
    r'''Copy the lower triangle of a ``w x w`` window of ``x`` into the
    bottom-left corner of an ``n x n`` chart filled with ``-inf``.

    Parameters:
        x (Tensor): tensor with 2 or more dims; the window is taken from the
            first two.
        n (int): size of the returned chart's first two dims (``w <= n``).
        w (int): size of the window; 0 yields an all ``-inf`` chart.
        offset (tuple): (row, col) of the window's top-left corner in ``x``.
        dim (int): 0 to transpose the window before the triangle is taken.

    Returns:
        Tensor ``[n, n, *x.shape[2:]]``.
    '''
    x = x[offset[0]:offset[0] + w, offset[-1]:offset[-1] + w]
    mask = x.new_ones(w, w).tril_().eq(1)
    tri = x.new_full((n, n, *x.shape[2:]), float('-inf'))
    if dim == 0:
        x = x.transpose(0, 1)
    # anchor the window at the last w rows; written as `n - w` because
    # `-w:` with w == 0 would select all n rows instead of none
    tri[n - w:, :w][mask] = x[mask]
    return tri
def triu(x, n, w, offset=(0, 0), dim=1):
    r'''Copy the upper triangle of a ``w x w`` window of ``x`` into the
    top-left corner of an ``n x n`` chart filled with ``-inf``.

    Parameters:
        x (Tensor): tensor with 2 or more dims; the window is taken from the
            first two.
        n (int): size of the returned chart's first two dims.
        w (int): size of the window.
        offset (tuple): (row, col) of the window's top-left corner in ``x``.
        dim (int): 0 to transpose the window before the triangle is taken.

    Returns:
        Tensor ``[n, n, *x.shape[2:]]``.
    '''
    row, col = offset[0], offset[-1]
    window = x[row:row + w, col:col + w]
    keep = window.new_ones(w, w).triu_().eq(1)
    chart = window.new_full((n, n, *window.shape[2:]), float('-inf'))
    source = window.transpose(0, 1) if dim == 0 else window
    chart[:w, :w][keep] = source[keep]
    return chart
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment