# coding: utf-8
import torch
from torch import nn

import hinton  # used only for debug Hinton-diagram plots when DEBUG is True

DEBUG = False
torch.autograd.set_detect_anomaly(DEBUG)
if DEBUG:
    torch.set_printoptions(precision=2)


def hintplot(x):
    # Print a Hinton diagram of a tensor when debugging is enabled.
    if DEBUG:
        print(hinton.plot(x.detach().cpu().numpy(), 1))
def extract_label_log_probs(log_probs, labels):
    # Gather per-state log-probabilities from the network output.
    # log_probs: (T, B, C), labels: (S', B)  ->  result: (T, B, S'),
    # where result[t, b, s] = log_probs[t, b, labels[s, b]].
    return log_probs[:, torch.arange(labels.size(1))[:, None], labels.t()]
def augment_labels(labels, blank_idx):
    # Interleave blanks with the labels: (S, B) -> (2S + 1, B), i.e.
    # [blank, l_1, blank, l_2, blank, ..., l_S, blank] down each column.
    blanks = torch.full_like(labels, blank_idx)
    labels_blanks = torch.cat([
        blanks[-1:],
        torch.stack((labels, blanks), dim=1).flatten(0, 1)
    ], dim=0)
    return labels_blanks
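# For example, with a single batch element and time-major labels [3, 7]:
#   augment_labels(torch.tensor([[3], [7]]), blank_idx=0)
#   -> tensor([[0], [3], [0], [7], [0]])   # blanks interleaved, length 2S + 1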
def create_transition_matrix(target_length):
    # Allowed-transition matrix over the 2S + 1 augmented states
    # (blanks at even indices, labels at odd indices); rows are the
    # source state, columns the destination.
    transition = torch.zeros(target_length, target_length, dtype=torch.float)
    idxs = torch.arange(target_length, dtype=torch.long)
    i = torch.tensor(1.)
    transition[idxs[::2], idxs[::2]] = i          # blank to blank (stay)
    transition[idxs[:-1], idxs[1:]] = i           # blank to label, label to blank
    transition[idxs[1:-3:2], idxs[3:-1:2]] = i    # label to label (skip a blank)
    return transition
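# For example, for two labels l1 and l2 the augmented states are
# [blank, l1, blank, l2, blank] and create_transition_matrix(5) is
#
#     [[1., 1., 0., 0., 0.],    # blank: stay, or advance to l1
#      [0., 0., 1., 1., 0.],    # l1: to blank, or skip to l2
#      [0., 0., 1., 1., 0.],    # blank: stay, or advance to l2
#      [0., 0., 0., 0., 1.],    # l2: to the final blank only
#      [0., 0., 0., 0., 1.]]    # final blank: stay
#
# Note that label states have no self-loops here, and the label-to-label
# skip is allowed even when adjacent labels are identical; both are points
# where this topology differs from the textbook CTC transition graph.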
def masked_logsumexp(log_p, mask):
    # logsumexp over dim 1 that ignores entries where mask is True.
    log_p = log_p.masked_fill(mask, log_p.min())
    k = torch.max(log_p, dim=1, keepdim=True)[0]
    exp_ = torch.exp(log_p - k).masked_fill(mask, 0.)
    sum_exp = torch.sum(exp_, dim=1, keepdim=True)
    return torch.log(sum_exp) + k
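# When at least one entry per row is unmasked this is equivalent to
#   torch.logsumexp(log_p.masked_fill(mask, float('-inf')), dim=1, keepdim=True)
# The reason for filling with log_p.min() rather than -inf appears to be
# numerical: it keeps -inf out of the max/exp computation, which would
# otherwise produce NaN gradients under autograd.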
@torch.jit.script
def forward_ctc_exp(extracted_log_probs, target_lengths, transition,
                    eps=torch.tensor(0.),
                    log_eps=torch.tensor(float('-inf'))):
    # extracted_log_probs: (T, B, 2S + 1) log-probabilities of each
    # augmented state at each time step; target_lengths: augmented
    # lengths (2 * label_length + 1) per batch element.
    T, _, max_length = extracted_log_probs.size()
    acc = torch.zeros_like(extracted_log_probs[0, :, 0])
    prev_probs = torch.zeros_like(extracted_log_probs[0, :])
    prev_probs[:, 0] = 1.  # all mass starts on the initial blank
    idxs = torch.arange(max_length,
                        dtype=torch.long,
                        device=target_lengths.device)
    idxs_t = torch.arange(T, dtype=torch.long,
                          device=target_lengths.device)
    # States not yet reachable from the initial blank at step t.
    reach_mask_start = idxs[None, :] >= 2 * (idxs_t[:, None] + 1)
    # States too far from the end to still reach the final states.
    reach_mask_end = (
        idxs[None, None, :] <
        target_lengths[None, :, None] - 2 * (T - idxs_t[:, None, None])
    )
    # States beyond each sequence's augmented length.
    length_mask = idxs[None, :] >= target_lengths[:, None]
    mask = (reach_mask_end |
            reach_mask_start[:, None, :] |
            length_mask[None, :, :])
    output_length = extracted_log_probs.size(0)
    for t in range(output_length):
        # Transition: propagate probability mass along allowed edges.
        curr_probs = torch.matmul(prev_probs,
                                  transition)
        # Mask pruned states and states that received no mass.
        t_mask = mask[t] | (curr_probs == 0.)
        # Keeping it log-safe: fill masked entries before taking the log.
        log_curr_probs = torch.log(
            curr_probs.masked_fill(t_mask, eps)
        ).masked_fill(t_mask, log_eps)
        # Emission: add the log-probability of each state at step t.
        log_curr_probs = log_curr_probs + extracted_log_probs[t]
        log_C = masked_logsumexp(log_curr_probs, t_mask)
        # Keeping it exp-safe: renormalise so the exp below cannot overflow.
        log_curr_probs = log_curr_probs - log_C
        exp_safe_mask = log_curr_probs < log_eps
        prev_probs = torch.exp(
            log_curr_probs.masked_fill(exp_safe_mask, log_eps)
        ).masked_fill(exp_safe_mask, 0.)
        # Accumulate the per-step log normalisers (HMM-style scaling).
        acc += log_C[:, 0]
    return acc
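# Why summing log_C yields the log-likelihood: the loop maintains scaled
# forward variables alpha_hat_t = alpha_t / C_t, so the unnormalised mass
# factors as prod_t C_t.  At the last step, reach_mask_end leaves only the
# two valid end states unmasked (the final label and the trailing blank),
# and alpha_hat_T sums to 1 over them; hence
#   log p(labels) = sum_t log C_t,
# which is exactly what acc accumulates.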
class Loss(nn.Module):
    __constants__ = ['eps', 'log_eps', 'transition']

    def __init__(self, blank,
                 transition_function=create_transition_matrix):
        super(Loss, self).__init__()
        self.blank_idx = blank
        # Softened stand-ins for the exact values 0. and float('-inf'),
        # chosen for numerical stability.
        self.eps = torch.tensor(1e-8)
        self.log_eps = torch.tensor(-64.)
        # Precompute the transition matrix once, up to 200 augmented states.
        self.max_length = 200
        self.transition = transition_function(self.max_length)
        self.first = True
    def forward(self, log_probs, targets, target_lengths,
                lwc_depths=None,
                rwc_depths=None):
        # lwc_depths / rwc_depths are currently unused (worst-case
        # pruning hooks left disabled).
        if self.first:
            # Lazily move the cached constants onto the input's device.
            self.eps = self.eps.to(log_probs.device)
            self.log_eps = self.log_eps.to(log_probs.device)
            self.transition = self.transition.to(log_probs.device)
            self.first = False
        labels_blanks = augment_labels(targets, self.blank_idx)
        extracted = extract_label_log_probs(log_probs, labels_blanks)
        transition = self.transition[:extracted.size(-1), :extracted.size(-1)]
        # Augmented lengths are 2 * label_length + 1.
        results = -forward_ctc_exp(extracted, (target_lengths * 2 + 1).long(),
                                   transition,
                                   eps=self.eps, log_eps=self.log_eps)
        if DEBUG:
            print(labels_blanks.size(), extracted.size())
            print(results)
        return results
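# A hypothetical variant transition function (a sketch, not used above):
# standard CTC also lets a label state repeat across frames, which adding
# label self-loops recovers.  Note, however, that textbook CTC additionally
# forbids the label-to-label skip when the two labels are identical; that
# constraint depends on the label sequence itself, so it cannot be encoded
# in a single static matrix shared across the batch.
def create_transition_matrix_with_repeats(target_length):
    transition = create_transition_matrix(target_length)
    idxs = torch.arange(target_length, dtype=torch.long)
    transition[idxs[1::2], idxs[1::2]] = 1.  # label self-loops
    return transition

# Usage sketch: Loss(-1, transition_function=create_transition_matrix_with_repeats)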
if __name__ == "__main__":
    # Test: compare the scaled log-space recursion against a naive
    # exp-space forward pass over the same transition matrix.
    torch.set_printoptions(precision=5)
    log_probs = torch.log_softmax(torch.randn(32, 5, 11), dim=-1)
    labels = torch.randint(0, 10, size=(5, 5))    # (S, B), time-major
    seq_lengths = torch.randint(5, 6, size=(5,))  # always 5 here
    seq_lengths, _ = seq_lengths.sort(descending=True)
    print("log_probs", log_probs.size())
    print("labels", labels.size())
    ctc = Loss(-1)  # blank = -1 indexes the last class (index 10)
    log_calc = ctc(log_probs, labels, seq_lengths)
    # Exp-space calculation, no scaling.
    labels_blanks = augment_labels(labels, -1)
    extracted_exp = extract_label_log_probs(log_probs, labels_blanks).exp()
    transition = create_transition_matrix(extracted_exp.size(2))
    prev_probs = torch.zeros_like(extracted_exp[0])
    prev_probs[:, 0] = 1.
    for i in range(log_probs.size(0)):
        prev_probs = torch.matmul(prev_probs, transition) * extracted_exp[i]
    # Sum over the two valid end states: the last label and the final blank.
    exp_calc = -torch.log(prev_probs[:, -2:].sum(1))
    print(log_calc, exp_calc)