Created February 1, 2016 15:39
Save bmander/df70f2a7ab1a8205882f to your computer and use it in GitHub Desktop.
Implement digamma function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 9e255c2..58338d1 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -266,7 +266,27 @@ def _ErfcGrad(op, grad):
 @ops.RegisterGradient("Lgamma")
-def _LgammaGrad(op, grad):  # pylint: disable=unused-argument
-  # TODO(ebrevdo): implement digamma
-  raise NotImplementedError("grad(Lgamma) == Digamma is not implemented")
+def _LgammaGrad(op, grad):
+  """Returns grad * digamma(x), since d/dx lgamma(x) = digamma(x).
+
+  There is no Digamma kernel, so digamma is composed from existing ops.
+  The recurrence digamma(x) = digamma(x + 6) - sum_{k=0}^{5} 1 / (x + k)
+  moves the argument to y = x + 6, where the asymptotic expansion
+
+    digamma(y) ~ log(y) - 1/(2y) - 1/(12y^2) + 1/(120y^4) - 1/(252y^6)
+
+  has truncation error below 3e-9 for y >= 6.  Every op is elementwise, so
+  x may have any shape.  Results are accurate for x > -6 away from the
+  poles at the non-positive integers; for x < -6, log(y) yields NaN.
+  """
+  x = op.inputs[0]
+  with ops.control_dependencies([grad.op]):
+    # Correction from shifting the argument up by 6 via the recurrence.
+    shift = sum(1.0 / (x + k) for k in range(6))
+    y = x + 6.0
+    y_inv = 1.0 / y
+    y_inv2 = y_inv * y_inv
+    # Horner form of the asymptotic series in 1/y^2.
+    series = y_inv2 * (1.0 / 12 - y_inv2 * (1.0 / 120 - y_inv2 / 252))
+    return grad * (math_ops.log(y) - 0.5 * y_inv - series - shift)


 @ops.RegisterGradient("Sigmoid")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.