-
-
Save harpone/3453185b41d8d985356cbe5e57d67342 to your computer and use it in GitHub Desktop.
import tensorflow as tf | |
from tensorflow.python.framework import ops | |
import numpy as np | |
# Define custom py_func which takes also a grad op as argument: | |
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    """Wrap ``tf.py_func`` so a custom Python gradient can be attached.

    Args:
        func: Python callable executed by the op (takes/returns numpy arrays).
        inp: list of input Tensors.
        Tout: list of output dtypes (e.g. ``[tf.float32]``).
        stateful: passed through to ``tf.py_func``.
        name: optional op name.
        grad: gradient function ``grad_fn(op, grad) -> Tensor``; see
            ``_MySquareGrad`` for an example. If None, no gradient is
            registered and a plain ``tf.py_func`` is returned.

    Returns:
        The Tensor(s) produced by ``tf.py_func``.
    """
    if grad is None:
        # Nothing to register; avoid RegisterGradient(None) blowing up later.
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
    # Gradient registrations are global and must be unique; a UUID cannot
    # collide, unlike the original random-integer suffix.
    import uuid
    rnd_name = 'PyFuncGrad' + uuid.uuid4().hex
    tf.RegisterGradient(rnd_name)(grad)
    g = tf.get_default_graph()
    # Any PyFunc op created inside this scope uses our registered gradient.
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
# Def custom square function using np.square instead of tf.square: | |
def mysquare(x, name=None):
    """Square ``x`` via numpy (``np.square``) instead of ``tf.square``.

    Uses the custom ``py_func`` wrapper so the op carries the hand-written
    gradient ``_MySquareGrad``.

    Args:
        x: input Tensor (assumed float32 — TODO confirm for other dtypes).
        name: optional name scope for the op.

    Returns:
        A float32 Tensor of the same shape as ``x``.
    """
    # ops.op_scope is long deprecated; name_scope takes (name, default, values).
    with ops.name_scope(name, "Mysquare", [x]) as name:
        sqr_x = py_func(np.square,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MySquareGrad)  # <-- here's the call to the gradient
        return sqr_x[0]
# Actual gradient: | |
def _MySquareGrad(op, grad): | |
x = op.inputs[0] | |
return grad * 20 * x # add a "small" error just to see the difference: | |
# Demo: evaluate the custom square op and its (deliberately wrong)
# gradient inside a TF1 session.
with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = mysquare(x)
    # tf.initialize_all_variables() is deprecated; use the modern alias.
    # (No variables are created here, but keep the init for parity.)
    sess.run(tf.global_variables_initializer())
    print(x.eval(), y.eval(), tf.gradients(y, x)[0].eval())
Thank you very much, guys — all of this helped me a lot. Two more questions:
- Could you explain why the examples by @IFLED and @adler-j give me the error 'cannot create weak reference to "numpy.ufunc" object', while @kristijanbartol's version works fine?
- Is it possible to handle an n-dimensional gradient? In that case, how does the '# second value is not used - it can be multiplied by zero with no side effects' trick behave?
Possible solution to 2:
Using a py_func for the custom gradient?
@Nirzi: The "cannot create weak reference" error comes because we're passing np.square
as the argument to the custom py_func
, which is then passed to tf.py_func
. The problem can be solved by instead defining a new function
def numpy_square(x):
return np.square(x)
and passing that named function to py_func
, instead of the "weak reference" to a numpy function. To be honest I'm not super clear on why, but I think it has to do with how Python does object lookups.
Hello guys, I cannot print inside the grad function. How can I make sure my calculation is running correctly?
Hi, can you show me how to use this to train networks such as CNNs? I am a beginner in this area, so when I try to train with this method there are always a lot of mistakes. Thanks.
A complete minimalistic example with actual gradient updates could also be useful: https://gist.github.com/kristijanbartol/1b7b7c5d431415284217bbf63ca25c66