Created
August 22, 2023 14:47
-
-
Save pythonlessons/589fe8741ed7527ae86f9296b748259c to your computer and use it in GitHub Desktop.
build_transformer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def Transformer(
    input_vocab_size: int,
    target_vocab_size: int,
    encoder_input_size: int = None,
    decoder_input_size: int = None,
    num_layers: int = 6,
    d_model: int = 512,
    num_heads: int = 8,
    dff: int = 2048,
    dropout_rate: float = 0.1,
) -> tf.keras.Model:
    """Build an encoder-decoder Transformer as a functional Keras model.

    The model takes two int64 inputs (encoder sequence, decoder sequence),
    runs them through the ``Encoder`` and ``Decoder`` layers defined elsewhere
    in the project, and projects the decoder output to ``target_vocab_size``
    units with a final ``Dense`` layer.

    Args:
        input_vocab_size (int): Vocabulary size passed to the encoder.
        target_vocab_size (int): Vocabulary size passed to the decoder; also
            the number of units of the final ``Dense`` layer.
        encoder_input_size (int): Length of the encoder input sequence, used
            as the only non-batch dimension of the encoder ``Input``. The
            default ``None`` leaves that dimension unspecified.
        decoder_input_size (int): Length of the decoder input sequence, used
            as the only non-batch dimension of the decoder ``Input``. The
            default ``None`` leaves that dimension unspecified.
        num_layers (int): Number of layers, forwarded to both the encoder and
            the decoder.
        d_model (int): Model dimensionality, forwarded to both stacks.
        num_heads (int): Number of attention heads, forwarded to both stacks.
        dff (int): Feed-forward dimensionality, forwarded to both stacks.
        dropout_rate (float): Dropout rate, forwarded to both stacks.

    Returns:
        A ``tf.keras.Model`` whose inputs are ``[encoder_input, decoder_input]``
        and whose output is the result of the final ``Dense`` layer. No
        activation is passed to that layer, so any normalization (e.g.
        softmax) is left to the caller or the loss.
    """
    encoder_input = tf.keras.layers.Input(shape=(encoder_input_size,), dtype=tf.int64)
    decoder_input = tf.keras.layers.Input(shape=(decoder_input_size,), dtype=tf.int64)

    # Both stacks receive the same hyperparameters; only the vocabulary differs.
    shared_hparams = dict(
        num_layers=num_layers,
        d_model=d_model,
        num_heads=num_heads,
        dff=dff,
        dropout_rate=dropout_rate,
    )

    encoder_output = Encoder(vocab_size=input_vocab_size, **shared_hparams)(encoder_input)
    # The decoder is called with its own input plus the encoder output.
    decoder_output = Decoder(vocab_size=target_vocab_size, **shared_hparams)(
        decoder_input, encoder_output
    )

    # Project each decoder position to one score per target-vocabulary entry.
    output = tf.keras.layers.Dense(target_vocab_size)(decoder_output)

    return tf.keras.Model(inputs=[encoder_input, decoder_input], outputs=output)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment