# GitHub gist by @adler-j, created February 27, 2017 11:33
# (gist id: 8d571eecbb1a6bbc0424cfaaee7a01aa).
# http://stackoverflow.com/questions/39921607/tensorflow-how-to-make-a-custom-activation-function-with-only-python
import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np
import odl
# Example linear operator: a 3x2 float32 matrix, viewed as a map R^2 -> R^3.
matrix = np.array([[1, 2],
                   [0, 0],
                   [0, 1]], dtype='float32')
# The number of columns fixes the domain dimension; the number of rows fixes
# the range dimension.
dom = odl.rn(matrix.shape[1], dtype='float32')
ran = odl.rn(matrix.shape[0], dtype='float32')
# ODL operator wrapping the matrix; used by the TensorFlow ops defined below.
odl_op = odl.MatrixOperator(matrix, dom, ran)
# Custom py_func wrapper that can also attach a gradient function.
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    """Wrap ``tf.py_func``, optionally registering a custom gradient for it.

    Parameters
    ----------
    func : callable
        Python function evaluated on the (numpy) values of ``inp``.
    inp : list of tf.Tensor
        Input tensors whose values are passed to ``func``.
    Tout : list of tf.DType
        Output dtype(s) of ``func``.
    stateful : bool, optional
        Forwarded unchanged to ``tf.py_func``.
    name : str, optional
        Name of the created op, forwarded to ``tf.py_func``.
    grad : callable, optional
        Gradient function with signature ``grad(op, grad)``. When given, it
        is registered under a fresh name and used as the gradient of the
        PyFunc op created here. When ``None``, no custom gradient is attached.

    Returns
    -------
    Whatever ``tf.py_func`` returns (callers in this file index it with
    ``[0]`` when ``Tout`` is a list).
    """
    if grad is None:
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

    # RegisterGradient names must be unique, so every call needs a fresh name.
    # A random integer from np.random could collide, and it also advances
    # NumPy's global RNG as a side effect. A UUID avoids both problems.
    import uuid
    rnd_name = 'PyFuncGrad' + uuid.uuid4().hex
    tf.RegisterGradient(rnd_name)(grad)
    g = tf.get_default_graph()
    # Inside this context, gradient lookups for "PyFunc" ops resolve to
    # `grad` instead of the default entry.
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
# Gradient helper: apply the adjoint of odl_op's derivative as a TF op.
def mysquaregrad(x, dx, name=None):
    """Return ``odl_op.derivative(x).adjoint(dx)`` as a TensorFlow tensor.

    This is the vector-Jacobian product of ``odl_op`` at ``x``, which is the
    quantity backpropagation needs for the forward op built by ``mysquare``.

    Parameters
    ----------
    x : tf.Tensor
        Point at which the derivative of ``odl_op`` is taken (the forward
        op's input).
    dx : tf.Tensor
        Incoming gradient with respect to the forward op's output.
    name : str, optional
        Name scope for the created op; defaults to ``"MySquareGrad"``.

    Returns
    -------
    tf.Tensor
        float32 gradient with respect to ``x``.

    Notes
    -----
    The op is built with ``py_func`` without a ``grad`` argument, so no
    custom gradient is attached to it; higher-order derivatives through it
    are not provided here.
    """
    def _impl(x, dx):
        return np.asarray(odl_op.derivative(x).adjoint(dx))

    # Pass both inputs so name_scope checks they belong to the same graph.
    with ops.name_scope(name, "MySquareGrad", [x, dx]) as name:
        grad_x = py_func(_impl,
                         [x, dx],
                         [tf.float32],
                         name=name)
        result = grad_x[0]
        # py_func drops static shape information. A gradient w.r.t. x always
        # has x's shape, so restore it for downstream shape inference.
        result.set_shape(x.get_shape())
        return result
# Gradient function registered for the PyFunc op created by `mysquare`.
def _MySquareGrad(op, grad):
    """Compute the gradient of the forward PyFunc op built in ``mysquare``.

    Parameters
    ----------
    op : tf.Operation
        The forward PyFunc op; its first input is the point ``x``.
    grad : tf.Tensor
        Incoming gradient with respect to the op's output.

    Returns
    -------
    tf.Tensor
        Gradient with respect to ``op.inputs[0]``, built by ``mysquaregrad``.
    """
    return mysquaregrad(op.inputs[0], grad)
# Forward op: evaluate the ODL operator `odl_op` through a Python callback.
def mysquare(x, name=None):
    """Apply ``odl_op`` to ``x`` as a TensorFlow op with a custom gradient.

    The name is inherited from the StackOverflow example this code adapts;
    no squaring happens. The op evaluates ``odl_op`` on the numpy value of
    ``x``, and ``_MySquareGrad`` is registered as its gradient.

    Parameters
    ----------
    x : tf.Tensor
        Input tensor; its value is passed to ``odl_op``.
    name : str, optional
        Name scope for the op; defaults to ``"Mysquare"``.

    Returns
    -------
    tf.Tensor
        ``odl_op(x)`` as a float32 tensor.
    """
    def _forward(x_value):
        return np.asarray(odl_op(x_value))

    with ops.name_scope(name, "Mysquare", [x]) as scope:
        # grad= attaches _MySquareGrad as this op's gradient.
        outputs = py_func(_forward, [x], [tf.float32], name=scope,
                          grad=_MySquareGrad)
        return outputs[0]
# Demo: evaluate the custom op and its gradient on a small input.
with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = mysquare(x)
    # tf.gradients seeds with ones, so this is matrix.T @ ones(3).
    dy_dx = tf.gradients(y, x)[0]
    # Fetch everything in one run. Separate .eval() calls would re-execute
    # the stateful PyFunc forward op for each value requested.
    # Expected: x = [1, 2], y = matrix @ x = [5, 0, 2], dy_dx = [1, 3].
    x_val, y_val, dy_dx_val = sess.run([x, y, dy_dx])
    print(x_val, y_val, dy_dx_val)
# End of gist. Comments can be left on the GitHub gist page (sign-in required).