
@vadimkantorov
Last active September 15, 2024 20:40
Looks up variable-length UTF-8 byte tokens from a vocab and concatenates them in pure PyTorch, fully vectorized, without Python loops
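The core of the vectorized variant is a standard trick for gathering several variable-length slices without a loop: emit one "jump" delta per slice followed by +1 steps, then take a cumulative sum to recover the gather indices. A miniature standalone sketch (the values here are invented for illustration and are not part of the gist):

import torch
# Gather the slices [1:3] and [7:9] from a flat array without a Python loop.
# Per slice: emit (start - previous_end + 1) once, then +1 for each remaining element.
deltas = torch.tensor([1, 1, 7 - 3 + 1, 1])  # [start0, +1, start1 - end0 + 1, +1]
print(deltas.cumsum(-1))                     # tensor([1, 2, 7, 8]) -> indices 1, 2 then 7, 8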
# Works only with a single, non-batched tensor of token ids.
import torch

def bpedetokenize_loop(token_ids, token_utf8bytes, token_lens):
    # Byte offsets into the flat vocab: inds[t] is where token t's bytes start.
    inds = torch.cat((torch.zeros_like(token_lens[:1]), token_lens.cumsum(-1)))
    # Reference implementation: slice out each token's bytes one by one and concatenate.
    return torch.cat([token_utf8bytes[inds[i] : inds[i_p_1]] for i, i_p_1 in zip(token_ids, token_ids + 1)])

def bpedetokenize_vec(token_ids, token_utf8bytes, token_lens):
    # Start / end byte offsets of every vocab token in the flat byte tensor.
    inds_begin = torch.cat((torch.zeros_like(token_lens[:1]), token_lens[:-1].cumsum(-1)))
    inds_end = inds_begin + token_lens
    # Offsets and lengths of the tokens actually being decoded.
    begins = inds_begin[token_ids]
    ends = inds_end[token_ids]
    lens = token_lens[token_ids]
    ones = torch.ones_like(token_ids)
    # Delta that jumps from the last byte of the previous token to the first byte of the current one.
    begins_shifted = torch.cat([begins[:1], begins[1:] - ends[:-1] + 1])
    # Interleave: one jump delta per token, followed by (len - 1) steps of +1.
    repeats = torch.stack([ones, lens - 1], dim = -1).flatten()
    i = torch.stack([begins_shifted, ones], dim = -1).flatten()
    # Cumulative sum over the deltas yields the gather index of every output byte.
    I = i.repeat_interleave(repeats).cumsum(-1)
    return token_utf8bytes[I]

if __name__ == '__main__':
    token_ids = torch.tensor([1, 0, 1, 3], dtype = torch.int64)
    token_utf8bytes = torch.tensor([1, 17, 31, 2, 2, 2, 2, 3, 7], dtype = torch.uint8)
    token_lens = torch.tensor([1, 2, 4, 2], dtype = torch.int64)
    print('loop:', bpedetokenize_loop(token_ids, token_utf8bytes, token_lens))
    print(' vec:', bpedetokenize_vec (token_ids, token_utf8bytes, token_lens))
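A minimal usage sketch on a toy vocab (the vocab contents and the helper name build_flat_vocab are hypothetical, invented for illustration; only the detokenize functions above come from the gist):

import torch

def build_flat_vocab(vocab_bytes):
    # Flatten a list of byte strings into the two tensors the gist expects:
    # all tokens' bytes laid out back to back, plus each token's byte length.
    token_utf8bytes = torch.tensor([b for tok in vocab_bytes for b in tok], dtype = torch.uint8)
    token_lens = torch.tensor([len(tok) for tok in vocab_bytes], dtype = torch.int64)
    return token_utf8bytes, token_lens

vocab = [b'he', b'llo', b' wor', b'ld']
token_utf8bytes, token_lens = build_flat_vocab(vocab)
token_ids = torch.tensor([0, 1, 2, 3], dtype = torch.int64)
decoded = bpedetokenize_vec(token_ids, token_utf8bytes, token_lens)
print(bytes(decoded.tolist()).decode('utf-8'))  # hello world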