Skip to content

Instantly share code, notes, and snippets.

@XinyueZ
Last active November 11, 2022 23:02
Show Gist options
  • Save XinyueZ/1a9fc50bde93bd88536026c8665f3ecd to your computer and use it in GitHub Desktop.
Save XinyueZ/1a9fc50bde93bd88536026c8665f3ecd to your computer and use it in GitHub Desktop.
PixelRNNs Many-To-Many
# Shared He-uniform initializer used for the kernel, recurrent and bias
# initializers of the encoder/decoder/predictor layers in this notebook.
# NOTE(review): the initializer class/alias itself (not an instance) is
# passed; presumably Keras instantiates it per layer -- confirm for the
# installed Keras version. Biases are commonly zero-initialized, so using
# He-uniform for them too looks deliberate but worth confirming.
initializer = keras.initializers.he_uniform
#@title Encoder
def create_encoder(input_shape, *, kernel_size=4, pool_size=4):
    """Build the convolutional encoder sub-model.

    Pipeline: Conv1D(``HIDDEN_SIZE`` filters, stride 1) -> BatchNormalization
    -> ReLU -> AvgPool1D(stride 1). No ``padding`` argument is given, so
    Keras' default padding applies to both the convolution and the pooling.

    Args:
        input_shape: Per-sample input shape, without the batch axis.
            Conv1D consumes rank-3 batched tensors, so this is presumably
            ``(steps, channels)`` -- confirm against ``image_flatten``.
        kernel_size: Width of the Conv1D kernel (default 4, the original
            hard-coded value).
        pool_size: Window of the AvgPool1D layer (default 4, the original
            hard-coded value).

    Returns:
        ``(encoder, shape)`` where ``encoder`` is a ``keras.Model`` named
        ``"encoder"`` and ``shape`` is the symbolic output shape *including*
        the leading batch dimension.
    """
    inputs = keras.layers.Input(shape=input_shape)
    # No ``input_shape=`` on Conv1D: the Input layer above already fixes the
    # shape, so the argument was redundant (newer Keras warns about it).
    x = keras.layers.Conv1D(
        filters=HIDDEN_SIZE,
        kernel_size=kernel_size,
        strides=1,
        kernel_initializer=initializer,
        bias_initializer=initializer,
    )(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation("relu")(x)
    y = keras.layers.AvgPool1D(pool_size=pool_size, strides=1)(x)
    encoder = keras.Model(inputs=[inputs], outputs=[y], name="encoder")
    return encoder, y.shape


encoder, encoded_shape = create_encoder(image_flatten.shape)
#@title Decoder
def create_decoder(input_shape):
    """Build the recurrent decoder sub-model.

    A bidirectional LSTM feeds a unidirectional LSTM (both with
    ``HIDDEN_SIZE`` units and ``return_sequences=True``). The resulting
    sequence is flattened to one vector per sample and batch-normalized.

    Args:
        input_shape: Per-sample shape of the encoder output, without the
            batch axis.

    Returns:
        ``(decoder, shape)``: the ``keras.Model`` named ``"decoder"`` and its
        symbolic output shape, which still carries the leading batch axis.
    """
    # Both recurrent layers are configured identically.
    lstm_config = dict(
        units=HIDDEN_SIZE,
        return_sequences=True,
        kernel_initializer=initializer,
        recurrent_initializer=initializer,
        bias_initializer=initializer,
    )
    seq_in = keras.layers.Input(shape=input_shape)
    hidden = keras.layers.Bidirectional(keras.layers.LSTM(**lstm_config))(seq_in)
    hidden = keras.layers.LSTM(**lstm_config)(hidden)
    # Explicit float32 dtype; presumably to keep this stage full precision
    # under a mixed-precision policy -- confirm.
    flat = keras.layers.Flatten(dtype='float32')(hidden)
    out = keras.layers.BatchNormalization()(flat)
    decoder = keras.Model(inputs=[seq_in], outputs=[out], name="decoder")
    return decoder, out.shape


decoder, decoded_shape = create_decoder(encoded_shape[1:])
#@title Predictor
def create_predictor(input_shape):
    """Build the dense prediction head.

    Pipeline: Dense(``input_size`` units) -> BatchNormalization -> ReLU
    -> tanh. Despite the model name ``"linear_predictor"``, the head ends
    in two non-linear activations.

    Args:
        input_shape: Per-sample shape of the decoder output, without the
            batch axis.

    Returns:
        ``(predictor, shape)``: the ``keras.Model`` and its symbolic output
        shape, which still carries the leading batch axis.
    """
    features = keras.layers.Input(shape=input_shape)
    head = [
        keras.layers.Dense(
            input_size,
            kernel_initializer=initializer,
            bias_initializer=initializer,
        ),
        keras.layers.BatchNormalization(),
        # NOTE(review): tanh applied after ReLU limits outputs to [0, 1).
        # Confirm this matches the target range and that stacking both
        # activations is intended.
        keras.layers.Activation("relu"),
        keras.layers.Activation("tanh"),
    ]
    out = features
    for layer in head:
        out = layer(out)
    predictor = keras.Model(inputs=[features], outputs=[out], name="linear_predictor")
    return predictor, out.shape


predictor, predictor_shape = create_predictor(decoded_shape[1:])
#@title Ensemble model
# Chain the three sub-models end to end: encoder -> decoder -> predictor.
inputs = keras.layers.Input(shape=image_flatten.shape)
enc_dec_prd = predictor(
    decoder(
        encoder(inputs)
    )
)
model = keras.Model(
    inputs=[inputs],
    outputs=[enc_dec_prd],
    name="pixelRNN",
)
model.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment