# Fashion-MNIST image retrieval with a dense autoencoder and nearest-neighbour search.
# Source: GitHub Gist jzuern/c8a1ac0a9516890f7945146c188415db
# (author @jzuern, created February 24, 2019 10:00).
from keras.layers import Input, Dense
from keras.models import Model
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import LSHForest
import matplotlib.pyplot as plt
from keras.datasets import fashion_mnist
# Autoencoder model definition.
# Symmetric fully connected autoencoder: 784 -> 128 -> 64 -> 32 -> 64 -> 128 -> 784.
# The 32-unit bottleneck layer is named 'encoded' so it can be retrieved by
# name after training and used as a standalone encoder.
input_img = Input(shape=(784,))
encoded = Dense(128, activation='relu')(input_img)
encoded = Dense(64, activation='relu')(encoded)
encoded = Dense(32, activation='relu', name='encoded')(encoded)
decoded = Dense(64, activation='relu')(encoded)
decoded = Dense(128, activation='relu')(decoded)
# Sigmoid output pairs with binary cross-entropy on pixel targets that are
# scaled to [0, 1] before training.
decoded = Dense(784, activation='sigmoid')(decoded)
autoencoder = Model(input_img, decoded)
# 'adam' rather than 'adadelta': current Keras / tf.keras ships Adadelta with a
# default learning rate of 0.001 (older Keras used 1.0), which leaves the
# network essentially untrained after the single training epoch used here.
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
# Load Fashion-MNIST; the class labels are discarded because the autoencoder
# is trained to reconstruct its input.
(x_train, _), (x_test, _) = fashion_mnist.load_data()
# Scale pixels to float32 values in [0, 1] and flatten every image into a
# single row so it matches the 784-wide input layer.
x_train = (x_train.astype('float32') / 255.).reshape(len(x_train), -1)
x_test = (x_test.astype('float32') / 255.).reshape(len(x_test), -1)
# Train the autoencoder to reproduce its own input (inputs double as the
# targets); the test split is only used to report validation loss.
autoencoder.fit(
    x=x_train,
    y=x_train,
    batch_size=128,
    epochs=1,
    shuffle=True,
    validation_data=(x_test, x_test),
    callbacks=[],
)
# Define the encoder as a new model: same input as the trained autoencoder,
# output taken from the bottleneck layer named 'encoded'.
layer_name = 'encoded'
bottleneck = autoencoder.get_layer(layer_name)
encoder = Model(inputs=autoencoder.input, outputs=bottleneck.output)
# Embed every (flattened) test image with the encoder; these feature vectors
# are what the neighbour search is built on.
x_test_encoded = encoder.predict(x_test)
# Restore the 28x28 image layout of the test set so images can be displayed.
x_test = x_test.reshape(-1, 28, 28)
# Build a nearest-neighbour index over the encoded test images.
# sklearn.neighbors.LSHForest (locality-sensitive hashing) was deprecated in
# scikit-learn 0.19 and removed in 0.21, so exact search with NearestNeighbors
# is used instead. The cosine metric keeps the reported distances comparable
# with LSHForest, which returned cosine distances.
# NOTE: on scikit-learn >= 0.21 the module-level
# `from sklearn.neighbors import LSHForest` import must be removed as well.
from sklearn.neighbors import NearestNeighbors

nn_index = NearestNeighbors(metric='cosine', algorithm='brute')
nn_index.fit(x_test_encoded)

# Pick a random query from the whole encoded test set (an upper bound of 1000
# would only ever sample the first 1000 images).
random_query = np.random.randint(0, len(x_test_encoded))
# kneighbors expects a 2-D array, so the query is kept as a single-row matrix.
query_features = np.expand_dims(x_test_encoded[random_query, :], axis=0)
# The query itself is in the index, so column 0 is normally the query
# (distance 0); the display loop starts at column 1.
distances, indices = nn_index.kneighbors(query_features, n_neighbors=5)
# Display the randomly chosen query image in grayscale on its own figure.
query_image = x_test[random_query]
plt.imshow(query_image)
plt.gray()
plt.title('Query image')
plt.show()
# Plot the retrieved neighbours side by side in one figure. Column 0 of the
# kneighbors result is normally the query itself, so it is skipped; the panel
# count follows the number of neighbours requested (n_neighbors=5 -> 4 panels).
n_shown = indices.shape[1] - 1
for i in range(1, indices.shape[1]):
    ax = plt.subplot(1, n_shown, i)
    plt.imshow(x_test[indices[0][i], :, :])
    plt.gray()
    plt.title('Distance = ' + str(distances[0][i]))
    # Hide axis ticks; only the images and their distances are of interest.
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
# Show once, after the loop, so all panels appear together in a single figure.
plt.show()
# Comments on this script can be left on the GitHub Gist page (GitHub sign-in required).