Looks up variable-length UTF-8 byte tokens from a vocab and concatenates them together in a fully vectorized way in pure PyTorch, without Python loops
# Detokenizes BPE token ids into a flat tensor of UTF-8 bytes.
# The vocab is stored flat: token_utf8bytes holds the bytes of all vocab tokens
# concatenated, and token_lens holds the byte length of each vocab token.
# Works only with a single, non-batched 1D tensor of token ids.

import torch

def bpedetokenize_loop(token_ids, token_utf8bytes, token_lens):
    # reference implementation: slice out each token's bytes in a Python loop
    inds = torch.cat((torch.zeros_like(token_lens[:1]), token_lens.cumsum(-1)))
    return torch.cat([token_utf8bytes[inds[i] : inds[i_p_1]] for i, i_p_1 in zip(token_ids, token_ids + 1)])

def bpedetokenize_vec(token_ids, token_utf8bytes, token_lens):
    # begin/end byte offsets of every vocab token in the flat byte buffer
    inds_begin = torch.cat((torch.zeros_like(token_lens[:1]), token_lens[:-1].cumsum(-1)))
    inds_end = inds_begin + token_lens
    begins = inds_begin[token_ids]
    ends = inds_end[token_ids]
    lens = token_lens[token_ids]
    ones = torch.ones_like(token_ids)
    # delta from the last byte of the previous token to the first byte of the current one
    begins_shifted = torch.cat([begins[:1], begins[1:] - ends[:-1] + 1])
    # per token: one "jump" delta followed by (len - 1) "+1" deltas; the cumsum of
    # the interleaved deltas yields the flat gather index for every output byte
    repeats = torch.stack([ones, lens - 1], dim = -1).flatten()
    i = torch.stack([begins_shifted, ones], dim = -1).flatten()
    I = i.repeat_interleave(repeats).cumsum(-1)
    return token_utf8bytes[I]

if __name__ == '__main__':
    token_ids = torch.tensor([1, 0, 1, 3], dtype = torch.int64)
    token_utf8bytes = torch.tensor([1, 17, 31, 2, 2, 2, 2, 3, 7], dtype = torch.uint8)
    token_lens = torch.tensor([1, 2, 4, 2], dtype = torch.int64)
    print('loop:', bpedetokenize_loop(token_ids, token_utf8bytes, token_lens))
    print(' vec:', bpedetokenize_vec(token_ids, token_utf8bytes, token_lens))
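
For a concrete round trip, here is a minimal usage sketch, assuming the definitions above are in scope; the toy vocab of byte strings is a hypothetical example, not part of the gist. It builds the two flat vocab tensors from Python byte strings and decodes the gathered bytes back into a string:

# hypothetical toy vocab: each entry is the UTF-8 encoding of one BPE token
vocab = [b'h', b'el', b'lo ', b'w', b'or', b'ld', '\u00e9'.encode('utf-8')]

# flatten the vocab into the layout the functions above expect
token_utf8bytes = torch.tensor([b for tok in vocab for b in tok], dtype = torch.uint8)
token_lens = torch.tensor([len(tok) for tok in vocab], dtype = torch.int64)

# token ids spelling out the whole vocab in order
token_ids = torch.arange(len(vocab), dtype = torch.int64)

out = bpedetokenize_vec(token_ids, token_utf8bytes, token_lens)
print(bytes(out.tolist()).decode('utf-8'))  # prints: hello worldé

The multi-byte last token exercises the variable-length path: the vectorized gather only ever indexes token_utf8bytes once, so the whole decode stays on-device with no Python-level iteration over tokens.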