Last active
August 1, 2023 00:59
-
-
Save harpone/3453185b41d8d985356cbe5e57d67342 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import uuid

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
# Define a custom py_func that also takes a gradient function as an argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    """Wrap ``tf.py_func`` so the resulting op can use a custom gradient.

    Args:
        func: Python callable executed by the op on NumPy arrays.
            NOTE(review): passing a NumPy ufuncs such as ``np.square`` directly
            has been reported to fail with "cannot create weak reference";
            wrap such functions in a plain ``def`` function.
        inp: List of input tensors passed to ``func``.
        Tout: Output dtype(s), forwarded to ``tf.py_func``.
        stateful: Forwarded to ``tf.py_func``.
        name: Optional op name, forwarded to ``tf.py_func``.
        grad: Gradient function with signature ``grad(op, grad_out)``. When
            ``None``, no gradient override is installed and the op keeps
            ``tf.py_func``'s default gradient behavior.

    Returns:
        Whatever ``tf.py_func`` returns for the given ``Tout``.
    """
    if grad is None:
        # Nothing to override: avoid registering a dummy entry in the
        # process-wide gradient registry.
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

    # Gradient names must be unique: tf.RegisterGradient raises if a name is
    # registered twice, so use a UUID rather than a random int that can collide.
    rnd_name = 'PyFuncGrad' + uuid.uuid4().hex
    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    # Every "PyFunc" op created inside this context uses the gradient above.
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
# Define a custom square function using np.square instead of tf.square:
def mysquare(x, name=None):
    """Square ``x`` element-wise with NumPy, using ``_MySquareGrad`` as gradient.

    Args:
        x: Input tensor (or tensor-convertible value). Output dtype is declared
            as ``tf.float32``, so this assumes ``x`` is float32 — a float64
            input would not match ``Tout``.
        name: Optional name for the op's name scope (default "Mysquare").

    Returns:
        A tensor holding ``np.square(x)``, with the same static shape as ``x``.
    """
    def _np_square(a):
        # Do not hand np.square to tf.py_func directly: NumPy ufuncs do not
        # support weak references, which triggers the "cannot create weak
        # reference" error. A plain Python function avoids that.
        return np.square(a)

    with ops.name_scope(name, "Mysquare", [x]) as name:
        x = ops.convert_to_tensor(x)
        sqr_x = py_func(_np_square,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MySquareGrad)  # <-- here's the call to the gradient
        # py_func outputs have unknown static shape; squaring is element-wise,
        # so the output shape equals the input shape.
        sqr_x[0].set_shape(x.get_shape())
        return sqr_x[0]
# Custom (deliberately scaled) gradient for mysquare:
def _MySquareGrad(op, grad):
    """Gradient function registered for the PyFunc op built by ``mysquare``.

    The true derivative of x**2 is 2*x. This intentionally uses a factor of 20
    instead, so the printed gradient shows that the custom gradient was used.

    Args:
        op: The forward op; ``op.inputs[0]`` is the original input.
        grad: Upstream gradient with respect to the op's output.

    Returns:
        ``grad * 20 * op.inputs[0]``.
    """
    scaled_upstream = grad * 20  # deliberate "small" error: true factor is 2
    return scaled_upstream * op.inputs[0]
# Demo: build y = mysquare(x) and print x, y, and dy/dx from the custom gradient.
with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = mysquare(x)
    dy_dx = tf.gradients(y, x)[0]
    # tf.initialize_all_variables() is deprecated; the graph currently has no
    # variables, but keep an initializer run so adding variables still works.
    sess.run(tf.global_variables_initializer())
    # Fetch all three tensors in a single run instead of three .eval() calls.
    x_val, y_val, grad_val = sess.run([x, y, dy_dx])
    print(x_val, y_val, grad_val)
Hello everyone, I can't print anything inside the gradient function. How can I make sure my calculation is correct?
Hi, can you show me how to use this to train networks such as a CNN? I am a beginner in this area, and when I try to train a network with this method I always get a lot of errors. Thanks.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@Nirzi: The "cannot create weak reference" error occurs because we pass `np.square` as the argument to the custom `py_func`, which then passes it on to `tf.py_func`. The problem can be solved by defining a new function instead (e.g. `def my_square(a): return np.square(a)`) and passing that named function to `py_func`, rather than the NumPy function itself. To be honest I'm not completely clear on why, but it appears to be because NumPy ufuncs such as `np.square` do not support weak references, whereas ordinary Python functions do.