@harpone
Last active August 1, 2023 00:59
import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np


# Define custom py_func which takes also a grad op as argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))

    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)


# Def custom square function using np.square instead of tf.square:
def mysquare(x, name=None):
    with ops.op_scope([x], name, "Mysquare") as name:
        sqr_x = py_func(np.square,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MySquareGrad)  # <-- here's the call to the gradient
        return sqr_x[0]


# Actual gradient:
def _MySquareGrad(op, grad):
    x = op.inputs[0]
    return grad * 20 * x  # add a "small" error just to see the difference


with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = mysquare(x)
    tf.initialize_all_variables().run()
    print(x.eval(), y.eval(), tf.gradients(y, x)[0].eval())
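
For reference, since the derivative of x**2 is 2*x, a version of the gradient without the deliberate factor-of-ten error above would be (an illustrative variant, not part of the gist itself):

def _MySquareGrad_correct(op, grad):
    x = op.inputs[0]
    return grad * 2 * x  # true derivative of x**2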
@kristijanbartol commented May 3, 2018

A complete, minimal example with actual gradient updates could also be useful: https://gist.github.com/kristijanbartol/1b7b7c5d431415284217bbf63ca25c66

import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np
import time

ZERO_TOL = 1e-8
LOSS_TOL = 1e-3
SAMPLES = 100
EPOCHS = 100000

train_input = np.random.rand(SAMPLES)
train_label = 3 * train_input


class MyException(Exception):
    pass


def _my_linear_grad(op, grad):
    # second value is not used - it can be multiplied by zero with no side effects
    return grad * op.inputs[1], grad * 0.


def _my_linear(a, x):
    return (a * x).astype(np.float32)


learning_rate = 1e-3
beta1 = 0.9999

x = tf.placeholder(dtype=tf.float32, shape=(), name='x')
y = tf.placeholder(dtype=tf.float32, shape=(), name='y')

a = tf.get_variable('a', dtype=tf.float32, initializer=1.)
tf_a = tf.get_variable('tf_a', dtype=tf.float32, initializer=1.)

with ops.op_scope([a, x], name="MyLinear") as name:
    # custom gradient op name shouldn't conflict with any other TF op name
    unique_name = 'PyFuncGrad@Unique'
    # use tf.RegisterGradient to set _my_linear_grad as the backward-pass function for the gradient op named unique_name
    tf.RegisterGradient(unique_name)(_my_linear_grad)

    g = tf.get_default_graph()

    # context manager used to override gradients for nodes created in its block
    with g.gradient_override_map({"PyFunc": unique_name}):
        # _my_linear is used for the forward pass; _my_linear and _my_linear_grad are wrapped inside a single TF node
        p = tf.py_func(_my_linear, [a, x], [tf.float32], stateful=True, name=name)

tf_p = tf_a * x

loss = tf.reduce_mean(tf.square(p - y))
tf_loss = tf.reduce_mean(tf.square(tf_p - y))

train_vars = [var for var in tf.trainable_variables()]
optim = tf.train.AdamOptimizer(learning_rate, beta1)

# compute_gradients returns a list so I can just concatenate them to calculate tf_loss, too
grads_and_vars = optim.compute_gradients(loss, var_list=train_vars)
grads_and_vars += optim.compute_gradients(tf_loss, var_list=train_vars)
train_op = optim.apply_gradients(grads_and_vars)

tf.summary.scalar('loss', loss)

with tf.Session() as sess:
    train_writer = tf.summary.FileWriter('board', sess.graph)
    merge = tf.summary.merge_all()

    sess.run(tf.global_variables_initializer())

    try:
        for epoch in range(EPOCHS):
            overall_loss = 0.
            # update using each sample separately
            for i in range(SAMPLES):
                result = sess.run([loss, tf_loss, a, tf_a, merge, train_op], feed_dict={
                    x: train_input[i],
                    y: train_label[i]
                })

                if np.abs(result[0] - result[1]) > ZERO_TOL:
                    print('Invalid update!\nExpected: {}, Actual: {}'.format(result[1], result[0]))
                    raise MyException

                print('epoch: {}, iter: {}, loss: {}\na: {}\n'.format(epoch, i, result[0], result[2]))
                overall_loss += result[0]

            overall_loss /= float(SAMPLES)
            print('overall_loss: {}'.format(overall_loss))
            #time.sleep(2.0)

            # [NOTE] this moment will be delayed a bit as it has to "wait" for the epoch to finish
            if overall_loss < LOSS_TOL:
                print('Found parameter!\n---------------\n')
                break

    except MyException:
        pass

@nicola-calonaci commented Sep 4, 2018

Thank you very much, guys. All of this helped me a lot. Two more questions:

  1. Could you say why the examples by @IFLED and @adler-j give me the error 'cannot create weak reference to "numpy.ufunc" object', while @kristijanbartol's works fine?
  2. Is it possible to deal with an n-dimensional gradient? In that case, how does the '# second value is not used - it can be multiplied by zero with no side effects' part behave?

Possible solution to 2:
Using a py_func for the custom gradient?
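
One way this could look (a rough sketch, not from the thread, assuming TF 1.x and the py_func helper from the gist above): the registered grad function receives grad with the same shape as the op's output and must return one tensor per op input, each matching that input's shape (which is why _my_linear_grad above returns a dummy zero gradient for its second input), and nothing prevents it from computing those tensors with another tf.py_func:

def _my_square_grad_nd(op, grad):  # hypothetical name; elementwise square, any array shape
    x = op.inputs[0]

    def np_grad(x_val, grad_val):
        # plain numpy chain rule for y = x**2, works for n-dimensional arrays
        return (grad_val * 2.0 * x_val).astype(np.float32)

    # a second py_func computes the gradient itself in numpy
    return tf.py_func(np_grad, [x, grad], [tf.float32])[0]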

@samlobel commented Feb 2, 2019

@Nirzi: The "cannot create weak reference" error comes because we're passing np.square as the argument to the custom py_func, which is then passed to tf.py_func. The problem can be solved by instead defining a new function

def numpy_square(x):
    return np.square(x)

and passing that named function to py_func instead of np.square itself, since a numpy ufunc cannot be weakly referenced. To be honest I'm not super clear on the details, but I think it has to do with how Python treats these objects.
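
For instance, with the gist's py_func helper above, the call inside mysquare would then read (a minimal sketch of the same fix):

sqr_x = py_func(numpy_square,  # plain Python function instead of the np.square ufunc
                [x],
                [tf.float32],
                name=name,
                grad=_MySquareGrad)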

@Musawar71

Hello guys, I cannot print inside the grad function. How can I make sure my calculation is going fine?
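
One way to inspect values in the backward pass under TF 1.x is to wrap the incoming gradient in tf.Print, an identity op that logs the listed tensors whenever it actually executes in a session (a rough sketch based on the gist's _MySquareGrad, offered as an assumption rather than a tested answer from this thread):

def _MySquareGrad(op, grad):
    x = op.inputs[0]
    # tf.Print passes grad through unchanged and logs grad and x each time
    # the gradient op runs:
    grad = tf.Print(grad, [grad, x], message='grad and x: ')
    return grad * 2 * x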

@hlx-hub commented Jan 21, 2020

Hi, can you show me how to use this to train networks such as CNNs? I am a beginner in this area, and when I try to train with this method there are always a lot of mistakes. Thanks.
