RNN Transducer in ~100 lines of NumPy code. Paper: https://arxiv.org/abs/1211.3711
from dataclasses import dataclass

import numpy as np

vocab = [None, 'a', 'b', 'c']
null_idx = 0
V = len(vocab)

@dataclass
class LSTMWeights:
    xi: np.ndarray  # (V-1, V) input -> input gate
    hi: np.ndarray  # (V, V) hidden -> input gate
    si: np.ndarray  # (V, V) cell state -> input gate (peephole)
    xf: np.ndarray  # (V-1, V) input -> forget gate
    hf: np.ndarray  # (V, V) hidden -> forget gate
    sf: np.ndarray  # (V, V) cell state -> forget gate (peephole)
    xs: np.ndarray  # (V-1, V) input -> cell state
    hs: np.ndarray  # (V, V) hidden -> cell state
    xo: np.ndarray  # (V-1, V) input -> output gate
    ho: np.ndarray  # (V, V) hidden -> output gate
    sh: np.ndarray  # (V, V) cell state -> output gate (peephole)

@dataclass
class RNNTransducerWeights:
    trans_f: LSTMWeights  # transcription (input) network, forward direction
    trans_b: LSTMWeights  # transcription (input) network, backward direction
    pred: LSTMWeights     # prediction (output) network

def one_hot(x: list[int], embed_size: int) -> np.ndarray:
    # transforms a sequence of token indexes into a sequence of vectors;
    # each vector has size embed_size and is 1 only at the original index, 0 elsewhere
    # example: one_hot([1, 2], 4) -> [[0, 1, 0, 0], [0, 0, 1, 0]]
    seq = np.zeros((len(x), embed_size))
    seq[np.arange(len(x)), x] = 1
    return seq

def softmax(x: np.ndarray) -> np.ndarray:
    # forces x to be a valid probability distribution:
    # sum(x) is 1 and each x_i is between 0 and 1
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def sigmoid(x: np.ndarray) -> np.ndarray:
    # squashes each element of x into the range (0, 1)
    return 1 / (1 + np.exp(-x))
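
# The cell below is the peephole LSTM variant used in the paper: the cell
# state feeds into the input, forget, and output gates (the s* matrices) in
# addition to the previous hidden state. It returns only the final hidden
# state rather than one vector per step, so callers re-encode prefixes from
# scratch; that keeps the code short at the cost of redundant computation.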

def lstm(seq: np.ndarray, hidden_size: int, W: LSTMWeights) -> np.ndarray:
    input_gate, forget_gate, output_gate, state, hidden = [np.zeros(hidden_size) for _ in range(5)]
    for x in seq:
        input_gate = sigmoid(W.xi.T @ x + W.hi.T @ hidden + W.si.T @ state)
        forget_gate = sigmoid(W.xf.T @ x + W.hf.T @ hidden + W.sf.T @ state)
        state = forget_gate * state + input_gate * np.tanh(W.xs.T @ x + W.hs.T @ hidden)
        output_gate = sigmoid(W.xo.T @ x + W.ho.T @ hidden + W.sh.T @ state)
        hidden = output_gate * np.tanh(state)
    return hidden
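
# The transducer factors into three networks (Graves 2012, section 2): a
# transcription network encoding the input sequence, a prediction network
# encoding the output tokens emitted so far, and a joiner combining the two
# into a distribution over the next token (including null). Here both encoders
# use hidden size V, so the joiner can simply add their final hidden states
# and softmax the result.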

def prediction_network(seq_y: np.ndarray, W: LSTMWeights) -> np.ndarray:
    return lstm(seq_y, V, W)

def transcription_network(seq_x: np.ndarray, W_f: LSTMWeights, W_b: LSTMWeights) -> np.ndarray:
    # bidirectional: sum of a forward pass and a backward pass over the input
    return lstm(seq_x, V, W_f) + lstm(seq_x[::-1], V, W_b)

def joiner_network(seq_x: np.ndarray, seq_y: np.ndarray, W: RNNTransducerWeights) -> np.ndarray:
    return transcription_network(seq_x, W.trans_f, W.trans_b) + prediction_network(seq_y, W.pred)

def rnn_transducer(seq_x: np.ndarray, seq_y: list[int], W: RNNTransducerWeights) -> np.ndarray:
    # returns Pr(k | x, y): a distribution over the next output token (null at index 0)
    # non-null tokens are 1..V-1, so shift them down by 1 to fit the V-1 one-hot dims
    seq_y_one_hot = one_hot([y - 1 for y in seq_y], len(vocab) - 1)
    logits = joiner_network(seq_x, seq_y_one_hot, W)
    return softmax(logits)
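
# Decoding follows Algorithm 1 (output sequence beam search) from the paper.
# B holds hypotheses that have consumed the input up to the current frame; A
# holds hypotheses still being extended at this frame. Before extending, each
# hypothesis absorbs the probability of every hypothesis in A that is a proper
# prefix of it ("prefix boosting"), since the same output sequence is reachable
# through multiple alignments.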

@dataclass
class Hypothesis:
    seq: list[int]  # decoded token indexes (never contains null_idx)
    logp: float     # total log probability of this output sequence

def decode_beam_search(input_seq: np.ndarray, W: RNNTransducerWeights, beam_size: int = 2) -> str:
    B = [Hypothesis([], 0)]
    for t in range(len(input_seq)):
        A = B
        B = []
        # prefix boosting: fold into y the probability of reaching it by first
        # emitting a proper prefix y_hat and then the remaining tokens at step t
        for y in A:
            boost_p = 0
            for y_hat in A:
                if len(y_hat.seq) < len(y.seq) and y_hat.seq == y.seq[:len(y_hat.seq)]:
                    y_hat_to_y_logp = y_hat.logp
                    for u in range(len(y_hat.seq), len(y.seq)):
                        log_probs = np.log(rnn_transducer(input_seq[:t + 1], y.seq[:u], W))
                        y_hat_to_y_logp += log_probs[y.seq[u]]
                    boost_p += np.exp(y_hat_to_y_logp)
            if boost_p:
                y.logp = np.log(np.exp(y.logp) + boost_p)
        # main decoding loop: repeatedly expand the most probable hypothesis in A
        y_star_idx, y_star = max(enumerate(A), key=lambda pair: pair[1].logp)
        while len([y for y in B if y.logp > y_star.logp]) < beam_size:
            del A[y_star_idx]
            log_probs = np.log(rnn_transducer(input_seq[:t + 1], y_star.seq, W))
            for k in range(len(log_probs)):
                if k == null_idx:
                    # emitting null consumes the input frame: hypothesis moves to B
                    B.append(Hypothesis(y_star.seq, y_star.logp + log_probs[k]))
                else:
                    A.append(Hypothesis(y_star.seq + [k], y_star.logp + log_probs[k]))
            y_star_idx, y_star = max(enumerate(A), key=lambda pair: pair[1].logp)
        # only keep the `beam_size` most probable hypotheses
        B = sorted(B, key=lambda hyp: hyp.logp, reverse=True)[:beam_size]
    # take the best length-normalized hypothesis (guard against the empty one)
    best_hyp = max(B, key=lambda hyp: hyp.logp / max(len(hyp.seq), 1))
    # de-tokenize
    return ''.join(vocab[idx] for idx in best_hyp.seq)
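
# Minimal usage sketch (an assumption, not part of the original gist): runs the
# decoder end to end with small random Gaussian weights, so the decoded string
# is arbitrary; it only demonstrates the expected shapes and the call sequence.
if __name__ == '__main__':
    rng = np.random.default_rng(0)

    def random_lstm_weights() -> LSTMWeights:
        # illustrative helper with the shapes documented on LSTMWeights
        shapes = dict(xi=(V - 1, V), hi=(V, V), si=(V, V),
                      xf=(V - 1, V), hf=(V, V), sf=(V, V),
                      xs=(V - 1, V), hs=(V, V),
                      xo=(V - 1, V), ho=(V, V), sh=(V, V))
        return LSTMWeights(**{k: rng.normal(0, 0.1, s) for k, s in shapes.items()})

    W = RNNTransducerWeights(trans_f=random_lstm_weights(),
                             trans_b=random_lstm_weights(),
                             pred=random_lstm_weights())
    input_seq = rng.normal(0, 1, (5, V - 1))  # 5 frames of (V-1)-dim "features"
    print(decode_beam_search(input_seq, W))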