Lhotse token collator for CTC
# Modified version with diff (see history)

from typing import List, Tuple

import numpy as np
import torch

from lhotse import CutSet


class TokenCollater:
    """Collate a list of tokens.

    Map sentences to integers. Sentences are padded to equal length.
    Beginning and end-of-sequence symbols can be added.

    Call .inverse(tokens_batch, tokens_lens) to reconstruct the batch as string sentences.

    Example:
        >>> token_collater = TokenCollater(cuts)
        >>> tokens_batch, tokens_lens = token_collater(cuts.subset(first=32))
        >>> original_sentences = token_collater.inverse(tokens_batch, tokens_lens)

    Returns:
        tokens_batch: IntTensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>,
            but before padding.
    """

    def __init__(
        self,
        cuts: CutSet,
        add_eos: bool = True,
        add_bos: bool = True,
        add_unk: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
        unk_symbol: str = "<unk>",
    ):
        self.pad_symbol = pad_symbol
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol
        self.unk_symbol = unk_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        # Build a character-level vocabulary from the first supervision of each cut,
        # placing the special symbols first so that <pad> maps to index 0.
        tokens = {char for cut in cuts for char in cut.supervisions[0].text}
        tokens_unique = (
            [pad_symbol]
            + ([unk_symbol] if add_unk else [])
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(tokens_unique)}
        self.idx2token = [token for token in tokens_unique]

    def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.Tensor]:
        # Join all supervision texts of each cut into a single character sequence.
        token_sequences = [
            " ".join(supervision.text for supervision in cut.supervisions)
            for cut in cuts
        ]
        max_len = len(max(token_sequences, key=len))

        # Wrap each sequence with optional <bos>/<eos> and pad to the longest sequence.
        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in token_sequences
        ]

        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        # Unpadded lengths, including <bos>/<eos> if they were added.
        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in token_sequences
            ]
        )

        return tokens_batch, tokens_lens

    def inverse(
        self, tokens_batch: torch.LongTensor, tokens_lens: torch.IntTensor
    ) -> List[str]:
        # Strip <bos>/<eos>/padding and map indices back to characters.
        start = 1 if self.add_bos else 0
        sentences = [
            "".join(
                [
                    self.idx2token[idx]
                    for idx in tokens_list[start : end - int(self.add_eos)]
                ]
            )
            for tokens_list, end in zip(tokens_batch, tokens_lens)
        ]
        return sentences
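
Usage sketch (not part of the gist): one way to feed the collater's output into torch.nn.CTCLoss. Here `cuts` is assumed to be a lhotse CutSet you have already loaded (as in the docstring example), the model output is faked with random log-probabilities, and reusing the <pad> index 0 as the CTC blank is an assumption layered on top of the gist, not something it prescribes.

# --- usage sketch, assuming `cuts` is an existing lhotse CutSet ---

collater = TokenCollater(cuts, add_bos=False, add_eos=False)  # CTC targets need no <bos>/<eos>
tokens_batch, tokens_lens = collater(cuts.subset(first=8))

# Index 0 is <pad> in this collater; it is reused here as the CTC blank
# (a common convention, but an assumption rather than part of the gist).
num_classes = len(collater.idx2token)

# Stand-in for real acoustic-model output: 200 frames, batch size B, C classes,
# shaped (T, B, C) as torch.nn.CTCLoss expects.
log_probs = torch.randn(200, tokens_batch.size(0), num_classes).log_softmax(dim=-1)
input_lengths = torch.full((tokens_batch.size(0),), 200, dtype=torch.int32)

ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=True)
loss = ctc_loss(log_probs, tokens_batch, input_lengths, tokens_lens)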