import tensorflow as tf


def gradients(f, x, grad_ys=None):
    '''
    An easier way of computing gradients in TensorFlow. It differs from tf.gradients as follows:
    * If f is not connected to x in the graph, the result will be zeros instead of None. This is more
      convenient for computing higher-order gradients, because a None gradient cannot be differentiated again.
    * The output will have the same shape and type as x. If x is a list, the output will be a list. If x is a Tensor, it