Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic
Last active June 15, 2019 20:32
Show Gist options
  • Save NMZivkovic/6a3998c229dc962d9e0ab29da6d17b5c to your computer and use it in GitHub Desktop.
Save NMZivkovic/6a3998c229dc962d9e0ab29da6d17b5c to your computer and use it in GitHub Desktop.
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model
class Autoencoder(object):
def __init__(self):
# Encoding
input_layer = Input(shape=(28, 28, 1))
encoding_conv_layer_1 = Conv2D(16, (3, 3), activation='relu', padding='same')(input_layer)
encoding_pooling_layer_1 = MaxPooling2D((2, 2), padding='same')(encoding_conv_layer_1)
encoding_conv_layer_2 = Conv2D(8, (3, 3), activation='relu', padding='same')(encoding_pooling_layer_1)
encoding_pooling_layer_2 = MaxPooling2D((2, 2), padding='same')(encoding_conv_layer_2)
encoding_conv_layer_3 = Conv2D(8, (3, 3), activation='relu', padding='same')(encoding_pooling_layer_2)
code_layer = MaxPooling2D((2, 2), padding='same')(encoding_conv_layer_3)
# Decoding
decodging_conv_layer_1 = Conv2D(8, (3, 3), activation='relu', padding='same')(code_layer)
decodging_upsampling_layer_1 = UpSampling2D((2, 2))(decodging_conv_layer_1)
decodging_conv_layer_2 = Conv2D(8, (3, 3), activation='relu', padding='same')(decodging_upsampling_layer_1)
decodging_upsampling_layer_2 = UpSampling2D((2, 2))(decodging_conv_layer_2)
decodging_conv_layer_3 = Conv2D(16, (3, 3), activation='relu')(decodging_upsampling_layer_2)
decodging_upsampling_layer_3 = UpSampling2D((2, 2))(decodging_conv_layer_3)
output_layer = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(decodging_upsampling_layer_3)
self._model = Model(input_layer, output_layer)
self._model.compile(optimizer='adadelta', loss='binary_crossentropy')
def train(self, input_train, input_test, batch_size, epochs):
self._model.fit(input_train,
input_train,
epochs = epochs,
batch_size=batch_size,
shuffle=True,
validation_data=(
input_test,
input_test))
def getDecodedImage(self, encoded_imgs):
decoded_image = self._model.predict(encoded_imgs)
return decoded_image
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment