Created
November 15, 2021 11:43
-
-
Save grey-area/a39588818edf369f0152d74fe9aa35bb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path
import json

# Vocabulary used by decode(). decode() indexes it with an integer class id,
# so it is presumably a JSON array of token strings -- verify against the
# resource file. The path is relative, so it is resolved against the current
# working directory, not against this file's location.
# JSON text is UTF-8 (RFC 8259); decode it explicitly as such instead of
# relying on the platform's locale encoding.
with Path('resources/tokens.json').open(encoding='utf-8') as f:
    tokens = json.load(f)
def decode(predictions, vocab=None):
    """Greedily decode per-frame class scores into a text string.

    For each frame, the index with the largest score selects a token from the
    vocabulary. The token is appended to the output unless it is '_' or is the
    same as the previous frame's token. Consecutive repeats therefore collapse,
    but a token repeated after an intervening '_' frame is emitted again.

    Args:
        predictions: 2D array indexed as ``predictions[i, j]`` with a
            ``shape`` attribute, shape (num_frames, num_classes). With the
            bundled vocabulary, num_classes is 999.
        vocab: sequence mapping class index -> token string. Defaults to the
            module-level ``tokens`` loaded from resources/tokens.json.

    Returns:
        The decoded text ('' when there are no frames).
    """
    if vocab is None:
        vocab = tokens
    num_frames = predictions.shape[0]
    num_classes = predictions.shape[1]
    pieces = []
    prev_token = '_'
    for i in range(num_frames):
        # max() returns the first maximal index on ties, matching a strict '>'
        # scan. Unlike a scan seeded with a sentinel such as -1e20, it cannot
        # fall through to index -1 when every score is extremely negative
        # (e.g. -inf log-probabilities).
        best = max(range(num_classes), key=lambda j: predictions[i, j])
        token = vocab[best]
        if token != '_' and token != prev_token:
            pieces.append(token)
        # Track every frame's token, including '_', so that a blank frame
        # separates two genuine occurrences of the same token.
        prev_token = token
    return ''.join(pieces)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment