@harpone
Last active August 1, 2023 00:59
# Written for the TensorFlow 1.x graph API (tf.Session, tf.py_func).
import tensorflow as tf
import numpy as np

# Define a custom py_func which also takes a grad op as argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    # Generate a unique name to avoid registering the same gradient name twice:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for a grad example
    g = tf.get_default_graph()
    # stateful=True creates a "PyFunc" op, stateful=False a "PyFuncStateless" op,
    # so override both op types:
    with g.gradient_override_map({"PyFunc": rnd_name,
                                  "PyFuncStateless": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

# Define a custom square function using np.square instead of tf.square:
def mysquare(x, name=None):
    # tf.name_scope(name, default_name, values) replaces the removed ops.op_scope
    with tf.name_scope(name, "Mysquare", [x]) as name:
        sqr_x = py_func(np.square,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MySquareGrad)  # <-- here's the call to the gradient
        return sqr_x[0]

# Actual gradient:
def _MySquareGrad(op, grad):
    x = op.inputs[0]
    # The true gradient is 2 * x; the factor 20 is a deliberate error
    # so that you can see the custom gradient is actually being used.
    return grad * 20 * x

with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = mysquare(x)
    tf.global_variables_initializer().run()
    # Gradient is 20 * x = [20, 40] instead of the true 2 * x = [2, 4]
    print(x.eval(), y.eval(), tf.gradients(y, x)[0].eval())
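
The wrapper relies on two steps: registering a gradient function under a fresh name with tf.RegisterGradient, then using gradient_override_map so that PyFunc ops created inside the `with` block look up that name instead of the default gradient. Below is a minimal pure-Python sketch of that lookup, not TensorFlow's implementation: ToyGraph, register_gradient and gradient are hypothetical names. We assume that registering the same name twice is an error, which is presumably why the gist draws a random suffix for every call.

```python
import contextlib
import random

# Hypothetical toy model of a gradient registry; NOT TensorFlow's real API.
_GRADIENT_REGISTRY = {}


def register_gradient(op_type):
    """Decorator recording a gradient function under op_type (cf. tf.RegisterGradient)."""
    def decorator(fn):
        if op_type in _GRADIENT_REGISTRY:
            # Assumed behavior: duplicate names are rejected.
            raise KeyError(f"gradient for {op_type!r} already registered")
        _GRADIENT_REGISTRY[op_type] = fn
        return fn
    return decorator


class ToyGraph:
    def __init__(self):
        self._override = {}

    @contextlib.contextmanager
    def gradient_override_map(self, mapping):
        # Temporarily redirect gradient lookups for the given op types.
        old = self._override
        self._override = {**old, **mapping}
        try:
            yield
        finally:
            self._override = old

    def create_op(self, op_type, fn, inputs):
        # The gradient name is captured when the op is created.
        grad_key = self._override.get(op_type, op_type)
        return {"type": op_type, "grad_key": grad_key,
                "input": inputs, "output": [fn(v) for v in inputs]}


def gradient(op, upstream):
    """Look up the op's gradient function and apply it to the upstream gradient."""
    return _GRADIENT_REGISTRY[op["grad_key"]](op, upstream)


@register_gradient("PyFunc")
def _default_square_grad(op, upstream):
    # Mathematically correct derivative of x**2.
    return [u * 2 * x for u, x in zip(upstream, op["input"])]


rnd_name = "PyFuncGrad" + str(random.randint(0, 10**8))


@register_gradient(rnd_name)
def _my_square_grad(op, upstream):
    # Deliberately wrong (20 * x instead of 2 * x), mirroring the gist.
    return [u * 20 * x for u, x in zip(upstream, op["input"])]


graph = ToyGraph()
plain = graph.create_op("PyFunc", lambda v: v * v, [1.0, 2.0])
with graph.gradient_override_map({"PyFunc": rnd_name}):
    overridden = graph.create_op("PyFunc", lambda v: v * v, [1.0, 2.0])

print(plain["output"], gradient(plain, [1.0, 1.0]))
print(overridden["output"], gradient(overridden, [1.0, 1.0]))
```

With an upstream gradient of ones, the overridden op returns 20 * x = [20.0, 40.0] while the plain op returns the true 2 * x = [2.0, 4.0], which is the same discrepancy the gist uses to show that its own gradient is being called.
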
@samlobel

samlobel commented Feb 2, 2019

@Nirzi: The "cannot create weak reference" error occurs because we're passing np.square as the argument to the custom py_func, which then passes it on to tf.py_func. The problem can be solved by defining a new function instead:

def numpy_square(x):
    return np.square(x)

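and passing that named function to py_func instead of np.square itself. To be honest, I'm not entirely sure why this works, but a likely explanation is that tf.py_func keeps the function in a registry that holds it only through a weak reference, and NumPy ufuncs such as np.square cannot be weakly referenced, whereas an ordinary Python function can.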
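
To illustrate the suspected mechanism, here is a stdlib-only sketch. ToyFuncRegistry and UfuncLike are hypothetical: the registry is assumed to store functions in a weakref.WeakValueDictionary, and UfuncLike is a callable whose instances cannot be weakly referenced, standing in for np.square as the error message suggests. Wrapping it in an ordinary `def` makes registration succeed.

```python
import weakref


class ToyFuncRegistry:
    """Hypothetical registry that, by assumption, keeps only weak references."""

    def __init__(self):
        self._funcs = weakref.WeakValueDictionary()
        self._next_token = 0

    def insert(self, func):
        token = f"pyfunc_{self._next_token}"
        self._next_token += 1
        # Raises TypeError if func does not support weak references.
        self._funcs[token] = func
        return token

    def call(self, token, *args):
        return self._funcs[token](*args)


class UfuncLike:
    """Callable whose instances have no __weakref__ slot (stand-in for a NumPy ufunc)."""
    __slots__ = ()

    def __call__(self, x):
        return x * x


square_ufunc = UfuncLike()


def plain_square(x):
    # An ordinary Python function supports weak references.
    return square_ufunc(x)


registry = ToyFuncRegistry()
try:
    registry.insert(square_ufunc)
except TypeError as err:
    print("registration failed:", err)

# plain_square stays alive as a module-level name, so the weak entry remains valid.
token = registry.insert(plain_square)
print(registry.call(token, 3))
```

In this sketch the caller keeps the strong reference alive; a real registry built this way would need something else (for example the graph) to hold it.
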

@Musawar71

Hello, I can't print anything from inside the grad function. How can I make sure my gradient calculation is correct?
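
Two notes, offered as a hedged suggestion. First, _MySquareGrad is called once, when tf.gradients builds the graph, so a plain print inside it shows symbolic Tensor objects rather than numbers; in TensorFlow 1.x, wrapping a tensor in tf.Print is one way to see its values when the session runs. Second, a custom gradient can be checked against a central finite difference, (f(x + eps) - f(x - eps)) / (2 * eps). TensorFlow 1.x provides tf.test.compute_gradient_error for this; the stdlib sketch below (all function names hypothetical) shows the idea for y = x**2, where summing the outputs matches the upstream gradient of ones that tf.gradients uses by default.

```python
def central_difference(f, xs, eps=1e-4):
    """Numerical gradient of a scalar function f at xs via central differences."""
    grads = []
    for i in range(len(xs)):
        plus = list(xs)
        plus[i] += eps
        minus = list(xs)
        minus[i] -= eps
        grads.append((f(plus) - f(minus)) / (2 * eps))
    return grads


def loss(xs):
    # Sum of y = x**2, i.e. the scalar whose gradient tf.gradients(y, x) computes.
    return sum(v * v for v in xs)


def analytic_grad(xs, factor):
    # Candidate gradient factor * x (the gist's _MySquareGrad uses factor=20).
    return [factor * v for v in xs]


def check(xs, factor, tol=1e-3):
    """Return True if the analytic gradient matches the numerical one."""
    num = central_difference(loss, xs)
    ana = analytic_grad(xs, factor)
    return all(abs(a - n) <= tol * max(1.0, abs(n)) for a, n in zip(ana, num))


xs = [1.0, 2.0]
print(check(xs, factor=2))   # correct gradient 2 * x passes
print(check(xs, factor=20))  # the deliberate 20 * x error is caught
```
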

@hlx-hub

hlx-hub commented Jan 21, 2020

Hi, can you show me how to use this to train a network such as a CNN? I'm a beginner in this area, and when I try to train with this method I keep running into errors. Thanks.
