Skip to content

Instantly share code, notes, and snippets.

@eileen-code4fun
Created January 21, 2022 05:56
Show Gist options
  • Save eileen-code4fun/b3380ef2ef69d26e36ee3b2374dba7d8 to your computer and use it in GitHub Desktop.
Save eileen-code4fun/b3380ef2ef69d26e36ee3b2374dba7d8 to your computer and use it in GitHub Desktop.
Translation Code
def translate(spa_text, model, max_seq=100):
    """Greedy-decode an English translation of a Spanish sentence.

    Encodes *spa_text* with the model's bidirectional Spanish encoder, then
    decodes one token at a time, attending over the encoder outputs, until
    '[END]' is produced or *max_seq* steps elapse.

    Args:
        spa_text: Spanish source sentence (a single string).
        model: Object exposing ``spa_text_processor``, ``spa_embedding``,
            ``spa_rnn`` (bidirectional LSTM returning fwd/bwd h and c states),
            ``eng_text_processor``, ``eng_embedding``, ``eng_rnn`` (LSTM),
            ``attention``, and ``out`` layers. (Inferred from usage —
            confirm against the model definition.)
        max_seq: Maximum number of decoding steps.

    Returns:
        Tuple of (translated sentence as a space-joined string,
        attention scores over the source for the decoded prefix).
    """
    spa_tokens = model.spa_text_processor([spa_text])  # Shape: (1, Ts)
    spa_vectors = model.spa_embedding(spa_tokens, training=False)  # (1, Ts, embedding_dim)
    spa_rnn_out, fhstate, fcstate, bhstate, bcstate = model.spa_rnn(
        spa_vectors, training=False)
    # Concatenate forward/backward states so the initial decoder state matches
    # the decoder LSTM's (doubled) width.
    spa_hstate = tf.concat([fhstate, bhstate], -1)
    spa_cstate = tf.concat([fcstate, bcstate], -1)
    state = [spa_hstate, spa_cstate]
    index_from_string = tf.keras.layers.StringLookup(
        vocabulary=model.eng_text_processor.get_vocabulary(),
        mask_token='')
    # Hoisted out of the loop: the vocabulary list is loop-invariant.
    vocab = model.eng_text_processor.get_vocabulary()
    trans = ['[START]']
    vectors = []
    for i in range(max_seq):
        token = index_from_string([[trans[i]]])  # Shape: (1, 1)
        vector = model.eng_embedding(token, training=False)  # (1, 1, embedding_dim)
        vectors.append(vector)
        # Query with the entire decoded prefix; attend over encoder outputs.
        query = tf.concat(vectors, axis=1)
        context = model.attention(inputs=[query, spa_rnn_out], training=False)
        # Feed only the context for the newest position into the decoder RNN.
        trans_vector, hstate, cstate = model.eng_rnn(
            context[:, -1:, :], initial_state=state, training=False)
        state = [hstate, cstate]
        out = model.out(trans_vector)  # Shape: (1, 1, eng_vocab_size)
        out = tf.squeeze(out)  # Shape: (eng_vocab_size,)
        word_index = tf.math.argmax(out)  # greedy choice
        word = vocab[word_index]
        trans.append(word)
        if word == '[END]':
            trans = trans[:-1]  # drop the end marker from the output
            break
    # BUG FIX: the original passed the raw Python list `vectors` as the
    # attention query; the attention layer expects a single tensor, so
    # concatenate along the time axis first — exactly as done inside the loop.
    query = tf.concat(vectors, axis=1)
    _, atts = model.attention(
        inputs=[query, spa_rnn_out], return_attention_scores=True,
        training=False)
    return ' '.join(trans[1:]), atts
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment