Compare proposed matrix exponential gradients with a version composed of existing differentiable ops, for symmetric matrices. See https://github.com/tensorflow/tensorflow/issues/15465. Based on https://gist.github.com/tvercaut/bd9fe8c5d12ab529babd9bf5434d7cda
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
import numpy as np
import scipy.linalg

def my_omexpovx_func(x):
    """
    A utility function to evaluate the scalar function (1 - exp(-x)) / x elementwise.
    """
    # Start with the basic formula
    out = np.divide(1 - np.exp(-x), x)
    # If the denominator is small, resort to a Taylor expansion
    # TODO: find a properly motivated threshold for switching to the Taylor expansion
    idx = (abs(x) < 1e-6)
    x2idx = np.square(x[idx])
    out[idx] = 1 - x[idx] / 2 + x2idx / 6
    return out

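# A quick illustrative sanity check (a sketch, not part of the original comparison):
# the Taylor branch should agree with the analytic limit (1 - exp(-x))/x -> 1 as
# x -> 0, and the direct formula should be returned away from zero. The test values
# in _xs are arbitrary choices made only for this check.
_xs = np.array([1e-9, 1e-3, 1.0])
assert np.allclose(my_omexpovx_func(_xs),
                   [1.0, (1.0 - np.exp(-1e-3)) / 1e-3, 1.0 - np.exp(-1.0)])
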
def expm_grad_from_najfeld(M, expM):
    """
    Compute the derivative of the matrix exponential.
    For the general case in which the matrix may not be diagonalisable,
    Najfeld and Havel also provide two other formulas using the adjoint
    of the matrix and some matrix functions.
    A numerical cross-check against the Frechet-derivative construction
    appears after the vec/unvec helpers below.
    """
    k2 = M.size
    k = int(np.sqrt(k2))
    assert k * k == k2
    # Some useful precomputations
    IkeM = np.kron(np.eye(k), expM)
    adM = np.kron(np.eye(k), M) - np.kron(M.T, np.eye(k))
    # Apply my_omexpovx_func as a matrix function to the adjoint
    funadM = scipy.linalg.funm(adM, my_omexpovx_func)
    # Use the first formula from Najfeld and Havel
    return np.dot(IkeM, funadM)

def expm_grad_from_frechet(M):
    """
    Compute the derivative of the matrix exponential by looking
    at the directional derivative (Frechet derivative) across
    the canonical basis
    """
    k2 = M.size
    k = int(np.sqrt(k2))
    assert k * k == k2
    DexpM = np.zeros((k2, k2))
    for i in range(0, k):
        for j in range(0, k):
            E = np.zeros((k, k))
            E[i, j] = 1
            eFME = scipy.linalg.expm_frechet(M, E, compute_expm=False)
            DexpM[:, j * k + i] = vec(eFME)
    return DexpM

def vec(mat):
    """
    Vectorise a matrix into a vector (column-major / Fortran order)
    """
    return mat.ravel('F')

def unvec(v):
    """
    Unvectorise a vector (column-major / Fortran order) representing a square matrix
    """
    k = int(np.sqrt(len(v)))
    assert k * k == len(v)
    return v.reshape((k, k), order='F')

def cvec(mat):
    """
    Vectorise a matrix into a vector (row-major / C order)
    """
    return mat.ravel('C')

def uncvec(v):
    """
    Unvectorise a vector (row-major / C order) representing a square matrix
    """
    k = int(np.sqrt(len(v)))
    assert k * k == len(v)
    return v.reshape((k, k), order='C')

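# Illustrative consistency check (a sketch): for a small random symmetric matrix,
# the Najfeld-Havel formula and the Frechet-derivative construction should yield
# (numerically) the same N^2 x N^2 Jacobian. The 3x3 test matrix _M is an
# arbitrary choice made only for this check.
_S = np.random.rand(3, 3)
_M = _S + _S.T
_diff = np.max(np.abs(expm_grad_from_najfeld(_M, scipy.linalg.expm(_M))
                      - expm_grad_from_frechet(_M)))
print("max |Najfeld - Frechet| Jacobian difference: {}".format(_diff))
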
@ops.RegisterGradient("MatrixExponential")
def _expm_grad(op, grad):
# We want the backward-mode gradient. Let X be the NxN input matrix.
# Let J(X) be the the N^2xN^2 complete Jacobian matrix of expm at X.
# Let Y be the NxN previous gradient in the backward AD. We want
# unvec( ( vec(Y)^T . J(X) )^T )
# = unvec( J(X)^T . vec(Y) )
# = unvec( J(X^T) . vec(Y) )
# which is the forward-mode derivative applied to the transpose
grad_func = lambda x, y: scipy.linalg.expm_frechet(x, y, compute_expm=False)
return tf.py_func(grad_func, [tf.transpose(op.inputs[0]), grad], tf.float64)
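# Illustrative check (a sketch): for a small random matrix _X and an arbitrary
# incoming gradient _Y, the forward-mode Frechet derivative at _X^T applied to _Y
# should match unvec(J(_X)^T . vec(_Y)) built from the explicit SciPy Jacobian,
# which is the identity the registered gradient above relies on. _X and _Y are
# arbitrary test values used only here.
_X = np.random.rand(3, 3)
_Y = np.random.rand(3, 3)
assert np.allclose(
    scipy.linalg.expm_frechet(_X.T, _Y, compute_expm=False),
    unvec(np.dot(expm_grad_from_frechet(_X).T, vec(_Y))))
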
def expm_eig(A):
    # Matrix exponential of a symmetric matrix via its eigendecomposition:
    # A = U diag(D) U^T  =>  expm(A) = U diag(exp(D)) U^T
    D, U = tf.self_adjoint_eig(A)
    eD = tf.diag(tf.exp(D))
    eA = tf.matmul(U, tf.matmul(eD, tf.transpose(U)))
    return eA

tf.enable_eager_execution()
tfe = tf.contrib.eager

# generate random symmetric square matrix
d = 5
x_np = np.random.rand(d, d)
x = tfe.Variable(x_np + x_np.T)

# compute forward pass and record gradients
with tf.GradientTape(persistent=True) as tape:
    expx = tf.linalg.expm(x)
    expx_eig = expm_eig(x)
grad_exp = tape.gradient(expx, x)
grad_exp_eig = tape.gradient(expx_eig, x)

# compare results visually
np.set_printoptions(precision=3)
print("\nx=\n {}\n".format(x.numpy()))
print("expm(x)=\n {}".format(expx))
print("expm_eig(x)=\n {}\n".format(expx_eig))
print("grad expm =\n {}".format(grad_exp))
print("grad expm_eig =\n {}".format(grad_exp_eig))