Created
July 14, 2018 19:08
-
-
Save nairouz/035a830d1e58a3759a6a1e193f5defed to your computer and use it in GitHub Desktop.
Error related to a Keras custom layer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Input
from tensorflow.keras.losses import kullback_leibler_divergence
# Eager execution has to be switched on at program start, before any op or
# layer is created. TensorFlow 2.x is eager by default and no longer exposes
# this function at the top level, so only call it where it exists.
if hasattr(tf, "enable_eager_execution"):
    tf.enable_eager_execution()
class ClusteringLayer(Layer):
    """Soft cluster-assignment layer using a Student's t kernel.

    The layer owns one trainable centroid per cluster, stored in ``self.W``
    with shape ``(output_dim, n_features)``. For every input row it computes
    the squared Euclidean distance ``d2`` to each centroid, maps it through
    ``(1 + d2 / alpha) ** (-(alpha + 1) / 2)`` and normalises across clusters
    so that each output row sums to one.

    Args:
        output_dim: Number of clusters, i.e. the width of the output.
        input_dim: Stored on the instance only; the feature size actually used
            is read from the input shape in ``build``.
        alpha: Degrees-of-freedom parameter of the kernel.
        **kwargs: Forwarded to the base ``Layer`` constructor (e.g. ``name``).
    """

    def __init__(self, output_dim, input_dim=None, alpha=1.0, **kwargs):
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.alpha = alpha
        super(ClusteringLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the centroid matrix once the feature dimension is known.

        Raises:
            ValueError: If the feature dimension (axis 1) is not defined.
        """
        # as_list() yields plain ints/None whether Keras hands over a
        # TensorShape, a tuple or a list, so no Dimension ``.value`` access
        # (which only exists in TensorFlow 1.x) is required.
        n_features = tf.TensorShape(input_shape).as_list()[1]
        if n_features is None:
            raise ValueError(
                'ClusteringLayer needs a defined feature dimension (axis 1), '
                'got input shape %s' % (input_shape,))
        self.W = self.add_weight(
            name='kernel',
            shape=(self.output_dim, n_features),
            initializer='uniform',
            trainable=True)
        super(ClusteringLayer, self).build(input_shape)

    def call(self, x, mask=None):
        """Return the row-normalised soft assignment of ``x`` to each centroid.

        ``mask`` is accepted for signature compatibility and is not used.
        """
        # For a 2-D ``x``, expand_dims gives (batch, 1, features); subtracting
        # W (clusters, features) broadcasts to (batch, clusters, features).
        # The squared distance is used directly: taking sqrt and squaring it
        # again is a no-op in the forward pass but yields NaN gradients when
        # an input coincides with a centroid.
        sq_dist = K.sum(K.square(K.expand_dims(x, 1) - self.W), axis=2)
        q = 1.0 / (1.0 + sq_dist / self.alpha)
        q = q ** ((self.alpha + 1.0) / 2.0)
        # keepdims makes the row sums broadcast against q without transposes.
        return q / K.sum(q, axis=1, keepdims=True)

    def compute_output_shape(self, input_shape):
        """Return the output shape ``(batch, output_dim)`` as a TensorShape."""
        # tf.keras expects a TensorShape here; a plain tuple can be mistaken
        # for a nested structure holding several output shapes.
        batch_size = tf.TensorShape(input_shape).as_list()[0]
        return tf.TensorShape([batch_size, self.output_dim])

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super(ClusteringLayer, self).get_config()
        config.update({
            'output_dim': self.output_dim,
            'input_dim': self.input_dim,
            'alpha': self.alpha,
        })
        return config
def clustering_loss(y_true, y_pred):
    """KL divergence between a sharpened target distribution and ``y_pred``.

    The target is derived from ``y_pred`` itself: each entry is squared,
    divided by the per-column sum of ``y_pred`` over the batch (axis 0), and
    every row is then rescaled to sum to one. No stop-gradient is applied to
    the target.

    Note that for a batch containing a single row whose entries already sum
    to one, the target equals ``y_pred`` and the divergence is zero.

    Args:
        y_true: Unused; present because Keras calls losses as
            ``loss(y_true, y_pred)``.
        y_pred: Soft assignments, one row per sample and one column per
            cluster.

    Returns:
        The result of ``kullback_leibler_divergence(target, y_pred)``.
    """
    column_totals = K.sum(y_pred, axis=0)
    sharpened = K.square(y_pred) / column_totals
    # Row-wise normalisation; keepdims lets the sums broadcast over columns.
    row_totals = K.sum(sharpened, axis=1, keepdims=True)
    target = sharpened / row_totals
    return kullback_leibler_divergence(target, y_pred)
# Minimal end-to-end run: a 10-feature input feeding a 5-cluster
# ClusteringLayer, trained with the self-supervised clustering loss.
input1 = Input(shape=(10,), name="input")
out = ClusteringLayer(output_dim=5, name='clustering')(input1)
model = Model(inputs=input1, outputs=out)
model.compile(
    optimizer=tf.train.AdamOptimizer(1e-3),
    loss={'clustering': clustering_loss})

X = np.random.random((20, 10))
# clustering_loss never reads its targets; Y only gives fit() an array whose
# shape matches the model output.
Y = np.random.random((20, 5))

# The loss builds its target from per-column sums over the batch. With a
# single-sample batch the target equals the prediction, the KL divergence is
# identically zero and nothing is learned, so train on batches of several rows.
model.fit(x={'input': X}, y={'clustering': Y}, batch_size=10, epochs=10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment