# coding: utf-8
# CTC loss in PyTorch via a rescaled forward algorithm over an
# interleaved label/blank state sequence.
# Gist by @shawntan, created November 26, 2019:
# https://gist.github.com/shawntan/b3131606240c6eb1bb9f8477cd12ad76

import torch
from torch import nn
try:
    # `hinton` is a hinton-diagram plotting helper; only needed when DEBUG.
    import hinton
except ImportError:
    hinton = None

DEBUG = False
torch.autograd.set_detect_anomaly(DEBUG)
if DEBUG:
    torch.set_printoptions(precision=2)

def hintplot(x):
    if DEBUG and hinton is not None:
        print(hinton.plot(x.detach().cpu().numpy(), 1))

def extract_label_log_probs(log_probs, labels):
    # Gather each sequence's per-state log-probabilities at every time step:
    # log_probs (T, B, C), augmented labels (2L + 1, B) -> (T, B, 2L + 1).
    return log_probs[:, torch.arange(labels.size(1))[:, None], labels.t()]
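
# Equivalent gather formulation (illustrative sketch, not in the original
# gist): the advanced index above selects
#   extracted[t, b, s] = log_probs[t, b, labels[s, b]],
# which can also be written as
#   idx = labels.t()[None].expand(log_probs.size(0), -1, -1)  # (T, B, 2L + 1)
#   extracted = log_probs.gather(2, idx)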

def augment_labels(labels, blank_idx):
    # Interleave a blank after every label and prepend a leading blank,
    # turning a (L, B) label tensor into a (2L + 1, B) state sequence.
    blanks = torch.full_like(labels, blank_idx)
    labels_blanks = torch.cat([
        blanks[-1:],
        torch.stack((labels, blanks), dim=1).flatten(0, 1)
    ], dim=0)
    return labels_blanks
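
# For example (illustrative): a batch column with labels (3, 7) and
# blank_idx = -1 becomes (-1, 3, -1, 7, -1). The stack/flatten pair emits
# label-blank pairs; the cat supplies the leading blank.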

def create_transition_matrix(target_length):
    # Left-to-right CTC transition structure over the augmented sequence
    # (b, l1, b, l2, ..., lL, b); transition[i, j] = 1 permits state i -> j.
    transition = torch.zeros(target_length, target_length, dtype=torch.float)
    idxs = torch.arange(target_length, dtype=torch.long)
    i = torch.tensor(1.)
    transition[idxs[::2], idxs[::2]] = i        # blank self-loops (even states)
    transition[idxs[:-1], idxs[1:]] = i         # blank -> label, label -> blank
    transition[idxs[1:-3:2], idxs[3:-1:2]] = i  # label -> label, skipping a blank
    return transition
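
# For example (not part of the original gist), with two labels the augmented
# sequence is (b, l1, b, l2, b) and create_transition_matrix(5) yields
#
#     [[1., 1., 0., 0., 0.],    # b  : stay, or advance to l1
#      [0., 0., 1., 1., 0.],    # l1 : advance to b, or skip to l2
#      [0., 0., 1., 1., 0.],    # b  : stay, or advance to l2
#      [0., 0., 0., 0., 1.],    # l2 : advance to the final b
#      [0., 0., 0., 0., 1.]]    # b  : stay (terminal state)
#
# Note that label states get no self-loop in this variant.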

def masked_logsumexp(log_p, mask):
    # logsumexp over dim=1 that ignores entries where mask is True,
    # using the usual max-shift for numerical stability.
    log_p = log_p.masked_fill(mask, log_p.min())
    k = torch.max(log_p, dim=1, keepdim=True)[0]
    exp_ = torch.exp(log_p - k).masked_fill(mask, 0.)
    sum_exp = torch.sum(exp_, dim=1, keepdim=True)
    return torch.log(sum_exp) + k
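
# Quick sanity check (illustrative, not in the original gist):
#   x = torch.tensor([[0., 1., 2.]])
#   m = torch.tensor([[False, False, True]])
#   masked_logsumexp(x, m)  # log(e^0 + e^1) ~= 1.3133, ignoring the masked 2.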

@torch.jit.script
def forward_ctc_exp(extracted_log_probs, target_lengths, transition,
                    # lwc_depths=None, rwc_depths=None,
                    eps=torch.tensor(0.),
                    log_eps=torch.tensor(float('-inf'))):
    # Rescaled forward pass: the state distribution is renormalised at every
    # step and the log normalisers are accumulated into `acc`, which ends up
    # as the log-likelihood of each label sequence.
    T, _, max_length = extracted_log_probs.size()
    acc = torch.zeros_like(extracted_log_probs[0, :, 0])
    prev_probs = torch.zeros_like(extracted_log_probs[0, :])
    prev_probs[:, 0] = 1.
    idxs = torch.arange(max_length,
                        dtype=torch.long,
                        device=target_lengths.device)
    idxs_t = torch.arange(T, dtype=torch.long,
                          device=target_lengths.device)
    # States not yet reachable from the initial blank at time t ...
    reach_mask_start = idxs[None, :] >= 2 * (idxs_t[:, None] + 1)
    # ... states from which the terminal states can no longer be reached ...
    reach_mask_end = (
        idxs[None, None, :] <
        target_lengths[None, :, None] - 2 * (T - idxs_t[:, None, None])
    )
    # ... and states beyond each sequence's augmented length.
    length_mask = idxs[None, :] >= target_lengths[:, None]
    mask = (reach_mask_end |
            reach_mask_start[:, None, :] |
            length_mask[None, :, :])
    # Disabled worst-case masks (would require lwc_depths / rwc_depths):
    # end_worst_case_mask = idxs[None, None, :] < 2 * rwc_depths[:, :, None]
    # mask = mask | end_worst_case_mask
    # mask_idxs = target_lengths[None, :] - 2 * lwc_depths
    # start_worst_case_mask = idxs[None, None, :] > mask_idxs[:, :, None]
    # mask = mask | start_worst_case_mask
    output_length = extracted_log_probs.size(0)
    for t in range(output_length):
        # Transition
        curr_probs = torch.matmul(prev_probs, transition)
        # Masks
        t_mask = mask[t] | (curr_probs == 0.)
        # Keeping it log-safe
        log_curr_probs = torch.log(
            curr_probs.masked_fill(t_mask, eps)
        ).masked_fill(t_mask, log_eps)
        log_curr_probs = log_curr_probs + extracted_log_probs[t]
        log_C = masked_logsumexp(log_curr_probs, t_mask)
        # Keeping it exp-safe
        log_curr_probs = log_curr_probs - log_C
        exp_safe_mask = log_curr_probs < log_eps
        prev_probs = torch.exp(
            log_curr_probs.masked_fill(exp_safe_mask, log_eps)
        ).masked_fill(exp_safe_mask, 0.)
        acc += log_C[:, 0]
    return acc
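
# Summary of the rescaling used above (assuming the standard rescaled forward
# algorithm): with alpha_t the forward distribution over states and
# C_t = sum_s alpha_t[s], the loop carries alpha_t / C_t and accumulates
# log C_t, so acc = sum_t log C_t = log p(labels | inputs); reach_mask_end
# ensures only the two terminal states (last label, final blank) survive
# at t = T - 1.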

class Loss(nn.Module):
    __constants__ = ['eps', 'log_eps', 'transition']

    def __init__(self, blank,
                 transition_function=create_transition_matrix):
        super(Loss, self).__init__()
        self.blank_idx = blank
        # Finite stand-ins for 0 and -inf keep masked entries gradient-safe.
        self.eps = torch.tensor(1e-8)
        self.log_eps = torch.tensor(-64.)
        self.max_length = 200
        self.transition = transition_function(self.max_length)
        self.first = True

    def forward(self, log_probs, targets, target_lengths,
                lwc_depths=None,
                rwc_depths=None):
        if self.first:
            # Lazily move the constant tensors to the input's device.
            self.eps = self.eps.to(log_probs.device)
            self.log_eps = self.log_eps.to(log_probs.device)
            self.transition = self.transition.to(log_probs.device)
            self.first = False
        labels_blanks = augment_labels(targets, self.blank_idx)
        extracted = extract_label_log_probs(log_probs, labels_blanks)
        # Crop the precomputed transition matrix to the augmented length 2L + 1.
        transition = self.transition[:extracted.size(-1), :extracted.size(-1)]
        results = -forward_ctc_exp(extracted, (target_lengths * 2 + 1).long(),
                                   transition,
                                   eps=self.eps, log_eps=self.log_eps)
        if DEBUG:
            print(labels_blanks.size(), extracted.size())
            print(results)
        return results
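
# Minimal usage sketch (illustrative; `model_outputs` and the shapes are
# assumptions, following the test below): log_probs is (T, B, C), targets is
# (L, B), and the blank is the last class index.
#
#   ctc = Loss(blank=-1)                 # indices are used as-is, so -1 works
#   log_probs = torch.log_softmax(model_outputs, dim=-1)
#   loss = ctc(log_probs, targets, target_lengths).mean()
#   loss.backward()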

if __name__ == "__main__":
    # Test: compare the masked log-space forward pass against a direct
    # probability-space forward pass.
    torch.set_printoptions(precision=5)
    log_probs = torch.log_softmax(torch.randn(32, 5, 11), dim=-1)
    labels = torch.randint(0, 10, size=(5, 5))
    seq_lengths = torch.randint(5, 6, size=(5,))
    seq_lengths, _ = seq_lengths.sort(descending=True)
    print("log_probs", log_probs.size())
    print("labels", labels.size())
    ctc = Loss(-1)
    log_calc = ctc(log_probs, labels, seq_lengths)

    # Exp-space calculation: plain forward pass, then take the probability
    # mass in the last two states (final label and final blank).
    labels_blanks = augment_labels(labels, -1)
    extracted_exp = extract_label_log_probs(log_probs, labels_blanks).exp()
    transition = create_transition_matrix(extracted_exp.size(2))
    prev_probs = torch.zeros_like(extracted_exp[0])
    prev_probs[:, 0] = 1.
    for i in range(log_probs.size(0)):
        prev_probs = torch.matmul(prev_probs, transition) * extracted_exp[i]
    exp_calc = -torch.log(prev_probs[:, -2:].sum(1))
    print(log_calc, exp_calc)
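
    # The two results should agree up to floating-point error (illustrative
    # check, not part of the original gist):
    print("match:", torch.allclose(log_calc, exp_calc, atol=1e-4))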