Skip to content

Instantly share code, notes, and snippets.

@nairouz
Created July 14, 2018 19:08
Show Gist options
  • Save nairouz/035a830d1e58a3759a6a1e193f5defed to your computer and use it in GitHub Desktop.
Save nairouz/035a830d1e58a3759a6a1e193f5defed to your computer and use it in GitHub Desktop.
Error related to a Keras custom layer
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Input
from tensorflow.keras.losses import kullback_leibler_divergence
# Switch the rest of this script to eager execution (TF 1.x API name).
tf.enable_eager_execution()
class ClusteringLayer(Layer):
    """Soft-assign each input sample to ``output_dim`` learnable cluster centres.

    Uses a Student's t kernel between each sample x_i and each centre w_j:

        q_ij = (1 + ||x_i - w_j||^2 / alpha) ^ (-(alpha + 1) / 2)

    and then normalises every row so each sample's assignments sum to 1.

    Args:
        output_dim: Number of clusters (rows of the centre matrix ``W``).
        input_dim: Stored on the layer (and in the config); not otherwise used here.
        alpha: Degrees-of-freedom parameter of the Student's t kernel.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``).
    """

    def __init__(self, output_dim, input_dim=None, alpha=1.0, **kwargs):
        super(ClusteringLayer, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.alpha = alpha

    def build(self, input_shape):
        """Create the (output_dim, feature_dim) matrix of cluster centres ``W``."""
        # as_list() yields a plain int on every TF version, whereas `.value`
        # only exists on TF 1.x `Dimension` objects.
        feature_dim = tf.TensorShape(input_shape).as_list()[-1]
        self.W = self.add_weight(name='kernel', shape=(self.output_dim, feature_dim),
                                 initializer='uniform', trainable=True)
        super(ClusteringLayer, self).build(input_shape)

    def call(self, x, mask=None):
        """Return the row-normalised soft assignments, shape (batch, output_dim).

        Assumes ``x`` is 2-D (batch, feature_dim), matching the shape of ``W``.
        """
        # (batch, 1, dim) - (clusters, dim) broadcasts to (batch, clusters, dim).
        # Use the squared distance directly: sqrt(...)**2 is redundant, and the
        # infinite derivative of sqrt at 0 yields NaN gradients whenever a
        # sample coincides exactly with a centre.
        sq_dist = K.sum(K.square(K.expand_dims(x, 1) - self.W), axis=2)
        q = 1.0 / (1.0 + sq_dist / self.alpha)
        q = q ** ((self.alpha + 1.0) / 2.0)
        # Normalise per sample (row) so the assignments sum to 1.
        return q / K.sum(q, axis=1, keepdims=True)

    def compute_output_shape(self, input_shape):
        """Return ``(batch, output_dim)`` where batch may be ``None``."""
        return (tf.TensorShape(input_shape).as_list()[0], self.output_dim)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super(ClusteringLayer, self).get_config()
        config.update({
            'output_dim': self.output_dim,
            'input_dim': self.input_dim,
            'alpha': self.alpha,
        })
        return config
def clustering_loss(y_true, y_pred):
    """Self-training KL(P || Q) loss on soft cluster assignments.

    ``y_pred`` is the assignment matrix Q, assumed to have shape
    (batch, n_clusters) with rows summing to 1 (as ``ClusteringLayer``
    produces). The target P sharpens Q: square it, divide by the per-cluster
    frequency f_j = sum_i q_ij over the batch, then re-normalise each row:

        p_ij = (q_ij^2 / f_j) / sum_k (q_ik^2 / f_k)

    Args:
        y_true: Ignored; the target is derived from ``y_pred`` itself.
        y_pred: Soft assignments Q.

    Returns:
        KL divergence between P and Q, computed by Keras'
        ``kullback_leibler_divergence`` over the last axis.

    Note:
        Because f_j is summed over the batch, a batch holding a single sample
        gives P == Q (for row-normalised Q), so the loss is exactly zero.
    """
    weight = K.square(y_pred) / K.sum(y_pred, axis=0)
    target = weight / K.sum(weight, axis=1, keepdims=True)
    # P acts as a fixed target distribution. Stop gradients from flowing
    # through it, so the optimiser moves Q towards P rather than also
    # moving P.
    target = K.stop_gradient(target)
    return kullback_leibler_divergence(target, y_pred)
# Minimal model: one ClusteringLayer (5 clusters) on 10-dimensional inputs.
input1 = Input(shape=(10,), name="input")
out = ClusteringLayer(output_dim=5, name='clustering')(input1)
model = Model(inputs=input1, outputs=out)
model.compile(optimizer=tf.train.AdamOptimizer(1e-3), loss={'clustering': clustering_loss})

# Cast explicitly: np.random.random returns float64, the Input layer defaults to float32.
X = np.random.random((20, 10)).astype(np.float32)
# Placeholder targets: clustering_loss ignores y_true, but fit() expects a target per output.
Y = np.random.random((20, 5)).astype(np.float32)

# batch_size must be > 1. clustering_loss normalises by cluster frequencies
# summed over the batch, so a single-sample batch yields a loss of exactly 0
# and nothing would be learned.
model.fit(x={'input': X}, y={'clustering': Y}, batch_size=10, epochs=10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment