Skip to content

Instantly share code, notes, and snippets.

@ljmartin
Created November 9, 2021 21:51
Show Gist options
  • Save ljmartin/ed3961406618cb461a786c2221e56ff9 to your computer and use it in GitHub Desktop.
Save ljmartin/ed3961406618cb461a786c2221e56ff9 to your computer and use it in GitHub Desktop.
import keras
import numpy as np
import tensorflow as tf
class D2M(keras.layers.Layer):
    """
    Converts an EDM `D` to its Gram matrix representation `M`.

    Computes, per batch element, ``M[i, j] = 0.5 * (D[0, j] + D[0, i] - D[i, j])``,
    i.e. the first point is used as the reference point.  Inputs are batched
    square matrices ``(batch, n_atoms, n_atoms)``.  The column term is read from
    row 0 of ``D``, so it equals ``D[i, 0]`` only when ``D`` is symmetric
    (NOTE(review): presumably ``D`` holds squared distances — confirm with callers).
    """

    def call(self, inputs, **kwargs):
        # Row 0 of every matrix, kept 3-D as (batch, 1, n) so it broadcasts
        # down the rows: first_row[b, i, j] == D[b, 0, j].
        first_row = inputs[:, :1, :]
        # The same values laid out as a column, (batch, n, 1):
        # first_col[b, i, j] == D[b, 0, i].
        first_col = tf.transpose(first_row, perm=[0, 2, 1])
        # Keep the summation order (-D + D_0j) + D_0i.
        gram = 0.5 * (-inputs + first_row + first_col)
        return gram
class D2T(keras.layers.Layer):
    """
    Converts a matrix `D` to the matrix `T = -.5 J D J`. `D` is EDM iff `T` is positive semi-definite.

    ``J = I - 1 1^T / n_atoms`` is the centering matrix.  It is built once in
    ``build`` with the layer's dtype and tiled over the batch axis in ``call``,
    so inputs are expected to be ``(batch, n_atoms, n_atoms)``.

    :param n_atoms: size of each square input matrix
    """

    def __init__(self, n_atoms, **kwargs):
        # Initialise the base Layer first (the conventional Keras order) so
        # base-class state exists before attributes are assigned.
        super().__init__(**kwargs)
        self.n_atoms = n_atoms

    def build(self, input_shape):
        n = self.n_atoms
        eye = tf.eye(num_rows=n, dtype=self.dtype)
        J = eye - tf.ones(shape=(n, n), dtype=self.dtype) / float(n)
        # Add a leading unit batch axis, (1, n, n), so call() can tile it.
        self.J = tf.reshape(J, shape=(-1, n, n), name="reshape_J")
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        # One copy of J per batch element.
        J = tf.tile(self.J, multiples=[tf.shape(inputs)[0], 1, 1])
        # D is EDM iff T is positive semi-definite
        T = -0.5 * tf.matmul(tf.matmul(J, inputs), J)
        return T

    def get_config(self):
        # Expose the constructor argument so the layer can be re-created
        # via from_config / model (de)serialization.
        config = super().get_config()
        config.update({"n_atoms": self.n_atoms})
        return config
def edm_loss(D, n_atoms):
    """
    Loss imposing a soft constraint on EDMness for the input matrix `D`.

    Penalises negative eigenvalues of ``T = -0.5 J D J``: the loss is the sum of
    their squares, so it is zero exactly when no eigenvalue is negative.

    :param D: a batch of hollow symmetric matrices, ``(batch, n_atoms, n_atoms)``
    :param n_atoms: number of atoms (size of each matrix)
    :return: loss value per batch element, shape ``(batch,)``
    """
    # D is EDM iff T = -0.5 JDJ is positive semi-definite.
    # No fixed layer name: Keras generates a unique one, so calling this more
    # than once inside the same functional model does not clash on names.
    T = D2T(n_atoms)(D)
    # eigvalsh treats T as symmetric, which holds when D is symmetric.
    eigenvalues = tf.linalg.eigvalsh(T)
    return tf.reduce_sum(tf.square(tf.nn.relu(-eigenvalues)), axis=-1)
# Demo input: a random matrix made symmetric and hollow (zero diagonal), which
# is the input contract edm_loss documents.  Such a matrix is generally not an
# EDM, so the printed loss is typically positive.
mat = np.random.uniform(size=[10, 10]) * 10
mat = 0.5 * (mat + mat.T)  # symmetrize
np.fill_diagonal(mat, 0.0)  # make hollow
# edm_loss expects a batch axis: shape (1, n_atoms, n_atoms).
print(edm_loss(mat[np.newaxis, :, :], mat.shape[0]))
def np_D2T(D):
    """
    NumPy counterpart of the ``D2T`` layer: return ``T = -0.5 * J @ D @ J``.

    ``J = I - 1 1^T / n_atoms`` is the centering matrix.  When ``D`` holds the
    squared pairwise distances of a point set, ``T`` is the Gram matrix of the
    centred points.

    :param D: square matrix ``(n_atoms, n_atoms)``, or a stack of them with
        shape ``(..., n_atoms, n_atoms)``; array-likes are accepted
    :return: ``T``, same shape as ``D``
    """
    D = np.asarray(D)
    # Size J from the last axis so leading batch axes are allowed.
    n_atoms = D.shape[-1]
    J = np.eye(n_atoms) - np.ones((n_atoms, n_atoms)) / float(n_atoms)
    # matmul broadcasts J over any leading batch dimensions of D.
    return -0.5 * (J @ D @ J)
def np_loss(D, n_atoms):
    """
    NumPy counterpart of ``edm_loss``: soft EDM-ness penalty for `D`.

    Computes ``T = -0.5 J D J`` via ``np_D2T`` and returns the sum of squares of
    its negative eigenvalues, totalled over all eigenvalues into one scalar.

    :param D: a hollow symmetric matrix
    :param n_atoms: number of atoms (not used by the computation; kept so the
        signature mirrors ``edm_loss``)
    :return: loss value (scalar)
    """
    # D is EDM iff T = -0.5 JDJ is positive semi-definite.
    centred = np_D2T(D)
    # eigvalsh treats its input as symmetric.
    eigenvalues = np.linalg.eigvalsh(centred)
    # Keep only the magnitude of negative eigenvalues; non-negative ones cost 0.
    negative_part = np.maximum(-eigenvalues, 0.0)
    return np.sum(negative_part ** 2)
# NumPy reference value for the same matrix, to compare with edm_loss above.
# np_loss already returns a scalar, so no extra .sum() is needed.
print(np_loss(mat, mat.shape[0]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment