An implementation of the inside-outside and Eisner algorithms.
# -*- coding: utf-8 -*-

import torch
from torch.nn.utils.rnn import pad_sequence


@torch.enable_grad()
def crf(scores, mask, target=None, partial=False):
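    r'''
    Computes the training loss of a first-order projective dependency CRF and,
    at prediction time, the arc marginal probabilities.

    The argument semantics below are inferred from the code and may differ
    from the author's intended usage:
        scores (Tensor): [batch_size, seq_len, seq_len], where scores[b, i, j]
            is the score of the arc headed by token j with dependent i.
        mask (BoolTensor): [batch_size, seq_len], True for real tokens;
            position 0 (the pseudo root) and padding are False.
        target (LongTensor): [batch_size, seq_len], the gold head of each
            token; None at prediction time.
        partial (bool): if True, target holds partial annotations and
            negative values mark unannotated tokens.

    Returns:
        (loss, None) during training, otherwise (loss, marginal_probs) with
        marginal_probs sharing the layout of scores.
    '''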
    lens = mask.sum(1)
    total = lens.sum()
    batch_size, seq_len, _ = scores.shape
    training = scores.requires_grad
    # always enable gradient computation of scores
    # so that marginal probs can be computed later
    s_ii, s_ic = inside(scores.requires_grad_(), mask)
    logZ = s_ic[0].gather(0, lens.unsqueeze(0)).sum()
    # use pseudo answers if target is None in the prediction phase
    if target is None:
        target = lens.new_zeros(batch_size, seq_len)
    # a second inside pass is needed if partial annotations are used
    if partial:
        total = total - target.lt(0).sum()
        s_ii, s_ic = inside(scores, mask, target)
        score = s_ic[0].gather(0, lens.unsqueeze(0)).sum()
    else:
        score = scores.gather(-1, target.unsqueeze(-1)).squeeze(-1)[mask].sum()
    loss = (logZ - score) / total
    if training:
        return loss, None
    # marginal probs are used for decoding, and can be computed by
    # combining the inside algorithm and the autograd mechanism
    # instead of the entire inside-outside process
    loss.backward()
    target_mask = target.unsqueeze(-1).eq(lens.new_tensor(range(seq_len)))
    marginal_probs = scores.grad * total + target_mask.float()
    return loss, marginal_probs


def inside(scores, mask, candidates=None):
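    r'''
    A brief description added for readability: this runs the inside pass of
    the Eisner algorithm in the log semiring, returning incomplete-span
    scores s_ii and complete-span scores s_ic, both of shape
    [seq_len, seq_len, batch_size]. If candidates is given (partial
    annotation), arcs excluded by it are assigned a score of -inf.
    '''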
    # the end position of each sentence in a batch
    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    # [seq_len, seq_len, batch_size]
    scores = scores.permute(2, 1, 0)
    # include position 0 (the pseudo root) of each sentence
    mask = mask.index_fill(1, lens.new_tensor(0), 1)
    # [seq_len, seq_len, batch_size]
    mask = (mask.unsqueeze(1) & mask.unsqueeze(-1)).permute(2, 1, 0)
    s_ii = torch.full_like(scores, float('-inf'))
    s_ic = torch.full_like(scores, float('-inf'))
    s_ic.diagonal().fill_(0)
    # set the scores of arcs excluded by candidates to -inf
    if candidates is not None:
        heads = candidates.unsqueeze(-1)
        heads = heads.index_fill_(1, lens.new_tensor(0), -1)
        candidates = heads.eq(lens.new_tensor(range(seq_len))) | heads.lt(0)
        candidates = candidates.permute(2, 1, 0) & mask
        scores = scores.masked_fill(~candidates, float('-inf'))
    for w in range(1, seq_len):
        # n denotes the number of spans to iterate,
        # from span (0, w) to span (n-1, n-1+w) given width w
        n = seq_len - w
        # diag_mask is used for ignoring padded positions of each sentence
        # [batch_size, n]
        cand_mask = diag_mask = mask.diagonal(w)
        # ilr = C(i, r) + C(j, r+1)
        # [n, w, batch_size]
        ilr = stripe(s_ic, n, w) + stripe(s_ic, n, w, (w, 1))
        if candidates is not None:
            cand_mask = torch.isfinite(ilr).any(1).t() & diag_mask
        ilr = ilr.permute(2, 0, 1)[cand_mask].logsumexp(-1)
        # I(j, i) = logsumexp(C(i, r) + C(j, r+1)) + S(j, i), i <= r < j
        il = ilr + scores.diagonal(-w)[cand_mask]
        # fill the w-th diagonal of the lower triangular part of s_ii
        # with I(j, i) of n spans
        s_ii.diagonal(-w)[cand_mask] = il
        # I(i, j) = logsumexp(C(i, r) + C(j, r+1)) + S(i, j), i <= r < j
        ir = ilr + scores.diagonal(w)[cand_mask]
        # fill the w-th diagonal of the upper triangular part of s_ii
        # with I(i, j) of n spans
        s_ii.diagonal(w)[cand_mask] = ir
        # C(j, i) = logsumexp(C(r, i) + I(j, r)), i <= r < j
        cl = stripe(s_ic, n, w, dim=0) + stripe(s_ii, n, w, (w, 0))
        if candidates is not None:
            cand_mask = torch.isfinite(cl).any(1).t() & diag_mask
        cl = cl.permute(2, 0, 1)[cand_mask].logsumexp(-1)
        s_ic.diagonal(-w)[cand_mask] = cl
        # C(i, j) = logsumexp(I(i, r) + C(r, j)), i < r <= j
        cr = stripe(s_ii, n, w, (0, 1)) + stripe(s_ic, n, w, (1, w), 0)
        if candidates is not None:
            cand_mask = torch.isfinite(cr).any(1).t() & diag_mask
        cr = cr.permute(2, 0, 1)[cand_mask].logsumexp(-1)
        s_ic.diagonal(w)[cand_mask] = cr
        # disallow multiple words modifying the root (single-root constraint)
        s_ic[0, w][lens.ne(w)] = float('-inf')
    return s_ii, s_ic


def outside(scores, mask):
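    r'''
    A brief description added for readability: this reuses inside() and then
    runs the outside pass of the inside-outside algorithm in the log semiring,
    returning the inside and outside scores of incomplete and complete spans
    (s_ii, s_ic, s_oi, s_oc), each of shape [seq_len, seq_len, batch_size].
    '''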
    s_ii, s_ic = inside(scores, mask)
    # the end position of each sentence in a batch
    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    # [seq_len, seq_len, batch_size]
    scores = scores.permute(2, 1, 0)
    # include position 0 (the pseudo root) of each sentence
    mask = mask.index_fill(1, lens.new_tensor(0), 1)
    # [seq_len, seq_len, batch_size]
    mask = (mask.unsqueeze(1) & mask.unsqueeze(-1)).permute(2, 1, 0)
    s_oi = torch.full_like(scores, float('-inf'))
    s_oc = torch.full_like(scores, float('-inf'))
    s_oi[0].scatter_(0, lens.unsqueeze(0), s_ic[lens, lens])
    s_oc[0].scatter_(0, lens.unsqueeze(0), 0)
    for w in reversed(range(seq_len - 1)):
        n = seq_len - w
        diag_mask = mask.diagonal(w) & lens.gt(w).unsqueeze(-1)
        # [n, n, batch_size]
        # II(r, j), OC(r, i), j < r < N
        iil = triu(s_ii, n, n - 1, (w+1, w), 0)
        ocl = triu(s_oc, n, n - 1, (w+1, 0), 0)
        # IC(r, i - 1), OI(j, r), OI(r, j), 0 <= r < i
        icr = tril(s_ic, n, n - 1, dim=0)
        oil = tril(s_oi, n, n - 1, (w+1, 0))
        oir = tril(s_oi, n, n - 1, (0, w+1), 0)
        # S(j, r), S(r, j), 0 <= r < i
        sl = tril(scores, n, n - 1, (w+1, 0))
        sr = tril(scores, n, n - 1, (0, w+1), 0)
        # cr = logsumexp(II(r, j) + OC(r, i)), j < r < N
        cr = (iil + ocl).permute(2, 0, 1).logsumexp(-1)
        # cll = logsumexp(IC(r, i - 1) + OI(j, r)) + S(j, r), 0 <= r < i
        cll = (icr + oil + sl).permute(2, 0, 1).logsumexp(-1)
        # clr = logsumexp(IC(r, i - 1) + OI(r, j)) + S(r, j), 0 <= r < i
        clr = (icr + oir + sr).permute(2, 0, 1).logsumexp(-1)
        # [batch_size, n]
        ocl = torch.stack((cr, cll, clr), -1).logsumexp(-1)
        # OC(j, i) = logsumexp(cr, cll, clr)
        s_oc.diagonal(-w)[diag_mask] = ocl[diag_mask]
        # [n, n, batch_size]
        # II(r, i), OC(r, j), 0 <= r < i
        iir = tril(s_ii, n, n - 1, (0, 1), 0)
        ocr = tril(s_oc, n, n - 1, (0, w+1), 0)
        # IC(r, j + 1), OI(r, i), OI(i, r), j < r < N
        icl = triu(s_ic, n, n - 1, (w+1, w+1), 0)
        oil = triu(s_oi, n, n - 1, (w+1, 0), 0)
        oir = triu(s_oi, n, n - 1, (0, w+1))
        # S(r, i), S(i, r), j < r < N
        sl = triu(scores, n, n - 1, (w+1, 0), 0)
        sr = triu(scores, n, n - 1, (0, w+1))
        # cl = logsumexp(II(r, i) + OC(r, j)), 0 <= r < i
        cl = (iir + ocr).permute(2, 0, 1).logsumexp(-1)
        # crl = logsumexp(IC(r, j + 1) + OI(r, i)) + S(r, i), j < r < N
        crl = (icl + oil + sl).permute(2, 0, 1).logsumexp(-1)
        # crr = logsumexp(IC(r, j + 1) + OI(i, r)) + S(i, r), j < r < N
        crr = (icl + oir + sr).permute(2, 0, 1).logsumexp(-1)
        # [batch_size, n]
        ocr = torch.stack((cl, crl, crr), -1).logsumexp(-1)
        if w != 0:
            ocr[:, 0] = float('-inf')
        # OC(i, j) = logsumexp(cl, crl, crr)
        s_oc.diagonal(w)[diag_mask] = ocr[diag_mask]
        # [n, n, batch_size]
        # IC(i, r), OC(j, r), 0 <= r <= i
        icl = tril(s_ic, n, n, (0, 0))
        ocl = tril(s_oc, n, n, (w, 0))
        # IC(j, r), OC(i, r), j <= r < N
        icr = triu(s_ic, n, n, (w, w))
        ocr = triu(s_oc, n, n, (0, w))
        # OI(j, i) = logsumexp(IC(i, r) + OC(j, r)), 0 <= r <= i
        oil = (icl + ocl).permute(2, 0, 1).logsumexp(-1)
        s_oi.diagonal(-w)[diag_mask] = oil[diag_mask]
        # OI(i, j) = logsumexp(IC(j, r) + OC(i, r)), j <= r < N
        oir = (icr + ocr).permute(2, 0, 1).logsumexp(-1)
        s_oi.diagonal(w)[diag_mask] = oir[diag_mask]
    return s_ii, s_ic, s_oi, s_oc
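

# A small sketch added for illustration (not part of the original gist): arc
# marginals can also be recovered from the full inside-outside pass via the
# standard identity p(arc) = exp(inside + outside - logZ) over incomplete
# spans, assuming s_ii and s_oi share the [seq_len, seq_len, batch_size]
# layout used above.
def marginal_probs_from_outside(scores, mask):
    s_ii, s_ic, s_oi, s_oc = outside(scores, mask)
    lens = mask.sum(1)
    # log partition of each sentence, read off the complete span C(0, len)
    logZ = s_ic[0].gather(0, lens.unsqueeze(0))
    # [seq_len, seq_len, batch_size]
    return (s_ii + s_oi - logZ).exp()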


def eisner(scores, mask):
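    r'''
    A brief description added for readability: first-order projective decoding
    with the Eisner algorithm. Given scores of shape
    [batch_size, seq_len, seq_len] and a mask of real tokens, it returns, for
    each sentence, the head of every position (padded across the batch), with
    the pseudo root at position 0.
    '''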
    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    scores = scores.permute(2, 1, 0)
    s_i = torch.full_like(scores, float('-inf'))
    s_c = torch.full_like(scores, float('-inf'))
    p_i = scores.new_zeros(seq_len, seq_len, batch_size).long()
    p_c = scores.new_zeros(seq_len, seq_len, batch_size).long()
    s_c.diagonal().fill_(0)
    for w in range(1, seq_len):
        n = seq_len - w
        starts = p_i.new_tensor(range(n)).unsqueeze(0)
        # ilr = C(i, r) + C(j, r+1)
        ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        # [batch_size, n, w]
        ilr = ilr.permute(2, 0, 1)
        il = ilr + scores.diagonal(-w).unsqueeze(-1)
        # I(j, i) = max(C(i, r) + C(j, r+1) + S(j, i)), i <= r < j
        il_span, il_path = il.max(-1)
        s_i.diagonal(-w).copy_(il_span)
        p_i.diagonal(-w).copy_(il_path + starts)
        ir = ilr + scores.diagonal(w).unsqueeze(-1)
        # I(i, j) = max(C(i, r) + C(j, r+1) + S(i, j)), i <= r < j
        ir_span, ir_path = ir.max(-1)
        s_i.diagonal(w).copy_(ir_span)
        p_i.diagonal(w).copy_(ir_path + starts)
        # C(j, i) = max(C(r, i) + I(j, r)), i <= r < j
        cl = stripe(s_c, n, w, dim=0) + stripe(s_i, n, w, (w, 0))
        cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
        s_c.diagonal(-w).copy_(cl_span)
        p_c.diagonal(-w).copy_(cl_path + starts)
        # C(i, j) = max(I(i, r) + C(r, j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
        s_c.diagonal(w).copy_(cr_span)
        # disallow multiple words modifying the root (single-root constraint)
        s_c[0, w][lens.ne(w)] = float('-inf')
        p_c.diagonal(w).copy_(cr_path + starts + 1)
    predicts = []
    p_c = p_c.permute(2, 0, 1).cpu()
    p_i = p_i.permute(2, 0, 1).cpu()
    for i, length in enumerate(lens.tolist()):
        heads = p_c.new_ones(length + 1, dtype=torch.long)
        backtrack(p_i[i], p_c[i], heads, 0, length, True)
        predicts.append(heads.to(mask.device))
    return pad_sequence(predicts, True)


def backtrack(p_i, p_c, heads, i, j, complete):
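    r'''
    A brief note added for readability: recursively follows the backpointers
    of incomplete spans (p_i) and complete spans (p_c) over the span (i, j)
    to fill in the head of each position in `heads` in place.
    '''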
    if i == j:
        return
    if complete:
        r = p_c[i, j]
        backtrack(p_i, p_c, heads, i, r, False)
        backtrack(p_i, p_c, heads, r, j, True)
    else:
        r, heads[j] = p_i[i, j], i
        i, j = sorted((i, j))
        backtrack(p_i, p_c, heads, i, r, True)
        backtrack(p_i, p_c, heads, j, r + 1, True)


def stripe(x, n, w, offset=(0, 0), dim=1):
    r'''Returns a diagonal stripe of the tensor.
    Parameters:
        x (Tensor): the input tensor with 2 or more dims.
        n (int): the length of the stripe.
        w (int): the width of the stripe.
        offset (tuple): the offset of the first two dims.
        dim (int): 0 if returns a horizontal stripe; 1 otherwise (default).
    Example::
        >>> x = torch.arange(25).view(5, 5)
        >>> x
        tensor([[ 0,  1,  2,  3,  4],
                [ 5,  6,  7,  8,  9],
                [10, 11, 12, 13, 14],
                [15, 16, 17, 18, 19],
                [20, 21, 22, 23, 24]])
        >>> stripe(x, 2, 3, (1, 1))
        tensor([[ 6,  7,  8],
                [12, 13, 14]])
        >>> stripe(x, 2, 3, dim=0)
        tensor([[ 0,  5, 10],
                [ 6, 11, 16]])
    '''
    seq_len = x.size(1)
    stride, numel = list(x.stride()), x[0, 0].numel()
    stride[0] = (seq_len + 1) * numel
    stride[1] = (1 if dim == 1 else seq_len) * numel
    return x.as_strided(size=(n, w, *x.shape[2:]),
                        stride=stride,
                        storage_offset=(offset[0]*seq_len+offset[1])*numel)


def tril(x, n, w, offset=(0, 0), dim=1):
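    r'''
    A brief note added for readability: selects a w*w block of x starting at
    `offset`, keeps the lower-triangular entries of that block (of its
    transpose when dim == 0), and scatters them into the bottom-left corner
    of an (n, n, ...) tensor filled with -inf.
    '''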
    x = x[offset[0]:offset[0]+w, offset[-1]:offset[-1]+w]
    mask = x.new_ones(w, w).tril_().eq(1)
    tri = x.new_full((n, n, *x.shape[2:]), float('-inf'))
    if dim == 0:
        x = x.transpose(0, 1)
    tri[-w:, :w][mask] = x[mask]
    return tri


def triu(x, n, w, offset=(0, 0), dim=1):
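    r'''
    A brief note added for readability: the upper-triangular counterpart of
    tril, scattering the selected block into the top-left corner of an
    (n, n, ...) tensor filled with -inf.
    '''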
    x = x[offset[0]:offset[0]+w, offset[-1]:offset[-1]+w]
    mask = x.new_ones(w, w).triu_().eq(1)
    tri = x.new_full((n, n, *x.shape[2:]), float('-inf'))
    if dim == 0:
        x = x.transpose(0, 1)
    tri[:w, :w][mask] = x[mask]
    return tri
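

# A minimal, hypothetical usage sketch (not part of the original gist): the
# batch size, sequence length, random scores and random gold heads below are
# made-up placeholders purely for illustration.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch_size, seq_len = 2, 6
    # scores[b, i, j]: score of the arc headed by token j with dependent i
    scores = torch.randn(batch_size, seq_len, seq_len)
    # mask of real tokens: position 0 is the pseudo root, and the second
    # sentence is assumed to be one token shorter than the first
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
    mask[:, 0] = False
    mask[1, -1] = False
    # random placeholder heads standing in for gold annotations
    target = torch.randint(1, seq_len, (batch_size, seq_len))
    # training loss of the projective dependency CRF
    loss, _ = crf(scores.requires_grad_(), mask, target)
    print('loss:', loss.item())
    # projective decoding with the Eisner algorithm
    print('heads:', eisner(scores.detach(), mask))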