@eileen-code4fun
Created January 21, 2022 05:58
Translation Call
import tensorflow as tf

class Spa2EngTranslator(tf.keras.Model):
  def __init__(self, eng_text_processor, spa_text_processor, unit=512):
    super().__init__()
    # __init__ is empty in the original gist; the layer definitions below are inferred from how call() uses them.
    self.eng_text_processor = eng_text_processor
    self.spa_text_processor = spa_text_processor
    self.spa_embedding = tf.keras.layers.Embedding(spa_text_processor.vocabulary_size(), unit, mask_zero=True)
    self.spa_rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(unit, return_sequences=True, return_state=True))
    # English embedding width matches the bidirectional encoder output so dot-product attention scores line up.
    self.eng_embedding = tf.keras.layers.Embedding(eng_text_processor.vocabulary_size(), 2 * unit, mask_zero=True)
    self.attention = tf.keras.layers.Attention()
    # Decoder width is 2 * unit so it can be seeded with the concatenated encoder states.
    self.eng_rnn = tf.keras.layers.LSTM(2 * unit, return_sequences=True, return_state=True)
    self.out = tf.keras.layers.Dense(eng_text_processor.vocabulary_size())

  def call(self, eng_text, spa_text):
    # Encode the Spanish source sentence.
    spa_tokens = self.spa_text_processor(spa_text)  # Shape: (batch, Ts)
    spa_vectors = self.spa_embedding(spa_tokens)  # Shape: (batch, Ts, embedding_dim)
    spa_rnn_out, fhstate, fcstate, bhstate, bcstate = self.spa_rnn(spa_vectors)  # Shapes: (batch, Ts, 2*unit), then four (batch, unit) states
    spa_hstate = tf.concat([fhstate, bhstate], -1)  # Shape: (batch, 2*unit)
    spa_cstate = tf.concat([fcstate, bcstate], -1)  # Shape: (batch, 2*unit)
    # Shift the English target for teacher forcing: predict token t+1 from tokens up to t.
    eng_tokens = self.eng_text_processor(eng_text)  # Shape: (batch, Te)
    expected = eng_tokens[:, 1:]  # Shape: (batch, Te-1)
    teacher_forcing = eng_tokens[:, :-1]  # Shape: (batch, Te-1)
    eng_vectors = self.eng_embedding(teacher_forcing)  # Shape: (batch, Te-1, embedding_dim)
    # Attend over the encoder outputs, masking padded positions on both sides.
    eng_in = self.attention(inputs=[eng_vectors, spa_rnn_out], mask=[eng_vectors._keras_mask, spa_rnn_out._keras_mask])
    trans_vectors, _, _ = self.eng_rnn(eng_in, initial_state=[spa_hstate, spa_cstate])  # Shape: (batch, Te-1, rnn_output_dim)
    out = self.out(trans_vectors)  # Shape: (batch, Te-1, eng_vocab_size)
    return out, expected, out._keras_mask
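
Below is a minimal usage sketch that is not part of the original gist: the sample sentences, the TextVectorization setup, and the masked-loss computation are illustrative assumptions (a real pipeline would also add start/end tokens to the English targets before adapting the processors).

# Hypothetical usage sketch; sample data and loss setup are illustrative assumptions.
eng_raw = tf.constant(['I love languages.', 'She reads books.'])
spa_raw = tf.constant(['Me encantan los idiomas.', 'Ella lee libros.'])
eng_text_processor = tf.keras.layers.TextVectorization()
spa_text_processor = tf.keras.layers.TextVectorization()
eng_text_processor.adapt(eng_raw)
spa_text_processor.adapt(spa_raw)

model = Spa2EngTranslator(eng_text_processor, spa_text_processor, unit=512)
logits, expected, mask = model(eng_raw, spa_raw)  # logits: (batch, Te-1, eng_vocab_size)

# Masked sparse cross-entropy: padded target positions do not contribute to the loss.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
per_token = loss_fn(expected, logits)  # Shape: (batch, Te-1)
mask = tf.cast(mask, per_token.dtype)
loss = tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)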