An implementation of the Minimum Word Error Rate (MWER) training loss based on the negative-sampling strategy from <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright 2022 Lucky Wong
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Minimum Word Error Rate Training loss

<Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
https://arxiv.org/abs/2206.08317

<Minimum Word Error Rate Training for Attention-based Sequence-to-Sequence Models>
https://arxiv.org/abs/1712.01818
"""
from typing import Optional, Tuple

import torch

MIN_LOG_VAL = torch.tensor(-float('inf'))
IGNORE_ID = -1
def make_pad_mask(lengths: torch.Tensor,
                  max_len: Optional[int] = None) -> torch.Tensor:
    """Make a mask tensor marking the padded part of each sequence.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): Length of the time axis; defaults to max(lengths).
    Returns:
        torch.Tensor: Boolean mask tensor, True at padded positions.
    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = int(lengths.size(0))
    if max_len is None:
        max_len = int(lengths.max().item())
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
def create_sampling_mask(log_softmax, n):
    """Generate the negative-sampling mask.

    # Ref: <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
    # https://arxiv.org/abs/2206.08317

    Args:
        log_softmax: log-softmax inputs, float32 (batch, maxlen_out, vocab_size)
        n: number of candidate paths, int32
    Return:
        sampling_mask: the sampling mask (nbest, batch, maxlen_out, vocab_size)
    """
    b, s, v = log_softmax.size()

    # Generate a random 0/1 mask for each of the n candidate paths
    nbest_random_mask = torch.randint(
        0, 2, (n, b, s, v), device=log_softmax.device)

    # Greedy-search decoding for the best path: (b, s)
    top1_score_indices = log_softmax.argmax(dim=-1)

    # Generate a one-hot mask marking the top-1 token at each position
    top1_score_indices_mask = torch.zeros(
        (b, s, v), dtype=torch.int, device=log_softmax.device)
    top1_score_indices_mask.scatter_(-1, top1_score_indices.unsqueeze(-1), 1)

    # Generate the sampling mask by applying the random mask to the top-1 tokens
    sampling_mask = nbest_random_mask * top1_score_indices_mask.unsqueeze(0)
    return sampling_mask
def negative_sampling_decoder(
    logit: torch.Tensor,
    nbest: int = 4,
    masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate multiple candidate paths with the negative-sampling strategy.

    # Ref: <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
    # https://arxiv.org/abs/2206.08317

    Args:
        logit: logit inputs, float32 (batch, maxlen_out, vocab_size)
        nbest: number of candidate paths, int32
        masks: padding masks for the logits, (batch, maxlen_out)
    Return:
        nbest_log_distribution: the N-best distribution of candidate paths (nbest, batch)
        nbest_pred: the N-best candidate paths (nbest, batch, maxlen_out)
    """
    # Log-softmax for the probability distribution
    log_softmax = torch.nn.functional.log_softmax(logit, dim=-1)

    # Generate the sampling mask
    with torch.no_grad():
        sampling_mask = create_sampling_mask(log_softmax, nbest)

    # Randomly mask the top-1 score with -float('inf')
    # (nbest, batch, maxlen_out, vocab_size)
    nbest_log_softmax = torch.where(
        sampling_mask != 0, MIN_LOG_VAL.type_as(log_softmax), log_softmax)

    # Greedy-search decoding on the sampled log-softmax
    nbest_logsoftmax, nbest_pred = nbest_log_softmax.topk(1)
    nbest_pred = nbest_pred.squeeze(-1)
    nbest_logsoftmax = nbest_logsoftmax.squeeze(-1)

    # Construct the N-best log probabilities
    # FIXME (huanglk): Ignore irrelevant probabilities
    # (n, b, s) -> (n, b): log(p1*p2*...*pn) = log(p1)+log(p2)+...+log(pn)
    nbest_log_distribution = torch.sum(
        nbest_logsoftmax.masked_fill(masks, 0), -1)
    return nbest_log_distribution, nbest_pred
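
# A minimal shape sketch (illustrative values, not from the paper): with
# logit of shape (batch=2, maxlen_out=10, vocab_size=500), nbest=4 and
# masks of shape (2, 10), negative_sampling_decoder returns
#   nbest_log_distribution with shape (4, 2)    -- one log probability per path
#   nbest_pred             with shape (4, 2, 10) -- one token sequence per path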
def compute_mwer_loss(
    nbest_log_distribution: torch.Tensor,
    nbest_pred: torch.Tensor,
    tgt: torch.Tensor,
    tgt_lens: torch.Tensor,
) -> torch.Tensor:
    """Compute the Minimum Word Error Rate Training loss.

    # Ref: <Minimum Word Error Rate Training for Attention-based Sequence-to-Sequence Models>
    # https://arxiv.org/abs/1712.01818

    Args:
        nbest_log_distribution: the N-best distribution of candidate paths (nbest, batch)
        nbest_pred: the N-best candidate paths (nbest, batch, maxlen_out)
        tgt: padded target token ids, int32 (batch, maxlen_out)
        tgt_lens: target token lengths of this batch (batch,)
    Return:
        loss: normalized MWER loss (batch,)
    """
    n, b, s = nbest_pred.size()

    # Mask out padded positions so they do not count as word errors
    # (b,) -> (b, s); <sos>/<eos> are not included
    masks = make_pad_mask(tgt_lens, max_len=tgt.size()[1])
    tgt = tgt.masked_fill(masks, IGNORE_ID)
    # (n, b, s)
    nbest_pred = nbest_pred.masked_fill(masks, IGNORE_ID)

    # Count the word errors of each candidate path
    # (b, s) -> (n, b, s)
    tgt = tgt.unsqueeze(0).repeat(n, 1, 1)
    # convert to float for the normalization below
    # (n, b, s) -> (n, b)
    nbest_word_err_num = torch.sum((tgt != nbest_pred), -1).float()

    # Total probability of the N-best hypotheses
    # (n, b) -> (b,): log(p1+p2+...+pn) = log(exp(log_p1)+exp(log_p2)+...+exp(log_pn))
    sum_nbest_log_distribution = torch.logsumexp(nbest_log_distribution, 0)

    # Re-normalize over just the N-best hypotheses
    # (n, b) - (b,) -> (n, b): exp(log_p)/exp(log_p_sum) = exp(log_p - log_p_sum)
    normal_nbest_distribution = torch.exp(
        nbest_log_distribution - sum_nbest_log_distribution)

    # Average number of word errors over the N-best hypotheses
    # (n, b) -> (b,)
    mean_word_err_num = torch.mean(nbest_word_err_num, 0)

    # Subtract the average word-error count from each hypothesis
    # (n, b) - (b,) -> (n, b)
    normal_nbest_word_err_num = nbest_word_err_num - mean_word_err_num

    # Expected number of word errors over the N-best hypotheses
    # (n, b) -> (b,)
    mwer_loss = torch.sum(normal_nbest_distribution *
                          normal_nbest_word_err_num, 0)
    return mwer_loss
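
# A compact restatement of what compute_mwer_loss evaluates, assuming
# W(y_i, y*) is the word-error count of hypothesis y_i against reference y*
# and P_hat re-normalizes the path probabilities over the N-best list:
#
#   L_mwer  = sum_i P_hat(y_i | x) * ( W(y_i, y*) - W_mean )
#   P_hat(y_i | x) = exp(log P(y_i | x)) / sum_j exp(log P(y_j | x))
#   W_mean  = (1 / N) * sum_i W(y_i, y*)
#
# Subtracting W_mean corresponds to the mean-subtraction used in
# https://arxiv.org/abs/1712.01818 to reduce gradient variance.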
class Seq2seqMwerLoss(torch.nn.Module):
    """Minimum Word Error Rate Training loss based on the negative-sampling strategy.

    <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
    https://arxiv.org/abs/2206.08317

    <Minimum Word Error Rate Training for Attention-based Sequence-to-Sequence Models>
    https://arxiv.org/abs/1712.01818

    Args:
        candidate_paths_num (int): The number of candidate paths.
        reduction (str): "sum", "mean", or anything else to return the per-utterance loss.
    """

    def __init__(
        self,
        candidate_paths_num: int = 4,
        reduction: str = "mean",
    ):
        super().__init__()
        self.candidate_paths_num = candidate_paths_num
        self.reduction = reduction

    def forward(self, logit: torch.Tensor, tgt: torch.Tensor,
                tgt_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logit: logits, float32 (batch, maxlen_out, vocab_size)
            tgt: padded target token ids, int64 (batch, maxlen_out)
            tgt_lens: target token lengths of this batch (batch,)
        Return:
            loss: reduced MWER loss
        """
        assert tgt_lens.size()[0] == tgt.size()[0] == logit.size()[0]
        assert logit.size()[1] == tgt.size()[1]

        # Randomly mask the top-1 score to generate multiple candidate paths
        masks = make_pad_mask(tgt_lens, max_len=tgt.size()[1])
        nbest_log_distribution, nbest_pred = negative_sampling_decoder(
            logit, self.candidate_paths_num, masks)

        # Compute the MWER loss
        mwer_loss = compute_mwer_loss(
            nbest_log_distribution, nbest_pred, tgt, tgt_lens)

        if self.reduction == "sum":
            return torch.sum(mwer_loss)
        elif self.reduction == "mean":
            return torch.mean(mwer_loss)
        else:
            return mwer_loss
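
# A minimal usage sketch with random tensors (shapes and values are
# illustrative only); it wires the pieces above together end to end.
if __name__ == "__main__":
    batch, maxlen_out, vocab_size = 2, 6, 10

    logit = torch.randn(batch, maxlen_out, vocab_size)
    tgt = torch.randint(0, vocab_size, (batch, maxlen_out))
    tgt_lens = torch.tensor([6, 4])

    mwer_loss_fn = Seq2seqMwerLoss(candidate_paths_num=4, reduction="mean")
    loss = mwer_loss_fn(logit, tgt, tgt_lens)
    print("MWER loss:", loss.item())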