Skip to content

Instantly share code, notes, and snippets.

@grey-area
Created November 15, 2021 11:43
Show Gist options
  • Save grey-area/a39588818edf369f0152d74fe9aa35bb to your computer and use it in GitHub Desktop.
Save grey-area/a39588818edf369f0152d74fe9aa35bb to your computer and use it in GitHub Desktop.
from pathlib import Path
import json
# Token table used to turn model output column indices into text.
# Presumably a JSON list where entry k is the token string for output column k
# (it is indexed by the winning column index) -- confirm against the data file.
# NOTE: the relative path is resolved against the current working directory,
# not against this module's location.
# JSON is UTF-8 by specification; say so explicitly instead of relying on the
# locale-dependent default encoding of open(), which breaks non-ASCII tokens on
# e.g. cp1252 systems.
with Path('resources/tokens.json').open(encoding='utf-8') as f:
    tokens = json.load(f)
def decode(predictions, vocab=None):
    """Greedily decode per-frame token scores into a text string.

    For every frame, the column with the largest score is the predicted token.
    That token is appended to the output unless it is the blank token ``'_'``
    or equal to the previous frame's token (a blank between two equal tokens
    therefore lets the repeat through). This collapse rule matches greedy CTC
    decoding with ``'_'`` as the blank.

    Args:
        predictions: 2D array-like (e.g. a NumPy array) supporting ``.shape``
            and row indexing, shape ``(num_frames, num_tokens)``. The number of
            token columns is taken from ``predictions.shape[1]`` rather than
            being fixed at 999.
        vocab: Sequence mapping a column index to its token string. Defaults
            to the module-level ``tokens`` table.

    Returns:
        The decoded text; ``''`` when there are no frames.

    Raises:
        ValueError: if ``predictions`` has frames but no token columns.
        IndexError: if a winning column index is out of range for ``vocab``.
    """
    if vocab is None:
        vocab = tokens
    num_frames = predictions.shape[0]
    num_classes = predictions.shape[1]
    if num_frames and not num_classes:
        raise ValueError('predictions has frames but no token columns')

    pieces = []
    # Start as if the previous frame were blank so the first real token is
    # always emitted.
    prev_token = '_'
    for i in range(num_frames):
        row = predictions[i]
        # max() returns the first of several equal maxima, matching the
        # original strict '>' scan. It needs no sentinel start value, so a row
        # whose scores are all <= -1e20 (e.g. -inf log-probabilities) no longer
        # falls through to index -1, i.e. the last token.
        best = max(range(num_classes), key=row.__getitem__)
        token = vocab[best]
        if token != '_' and token != prev_token:
            pieces.append(token)
        # Track every frame, blanks included, so "a _ a" decodes to "aa".
        prev_token = token
    return ''.join(pieces)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment