Created December 14, 2010 18:42
-
-
Save yatsuta/740850 to your computer and use it in GitHub Desktop.
grad
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
def pderiv(f, i, init_h=0.1, tol=1e-10, max_iter=50):
    """Return a function estimating the partial derivative of ``f`` w.r.t. its ``i``-th argument.

    The estimate is a central difference,
    ``(f(..., x_i + h, ...) - f(..., x_i - h, ...)) / (2 * h)``,
    starting at step ``init_h`` and halving the step until two consecutive
    estimates differ by less than ``tol``.  Round-off error grows as ``h``
    shrinks, so refinement also stops as soon as the change between
    consecutive estimates starts growing again; the previous (better)
    estimate is returned then.  At most ``max_iter`` halvings are done, so
    the step can never shrink to the point where ``x_i + h == x_i``.

    Note: ``f`` is evaluated on both sides of the point, so it must be
    defined within ``init_h`` of the point along axis ``i``.

    Args:
        f: callable taking the point's coordinates as positional arguments
            and returning a number.
        i: index of the argument to differentiate with respect to.
        init_h: initial (largest) step size.
        tol: absolute tolerance between consecutive estimates.
        max_iter: maximum number of step halvings.

    Returns:
        A function ``pderiv_f(*arg)`` returning the derivative estimate at
        the point ``arg``.  Raises IndexError if ``i`` is out of range for
        ``arg``.
    """
    def shifted(arg, h):
        # Copy of ``arg`` with only the i-th coordinate moved by ``h``.
        larg = list(arg)
        larg[i] = larg[i] + h
        return tuple(larg)

    def central(arg, h):
        # Central difference: truncation error is O(h**2), versus O(h)
        # for the one-sided (forward) difference.
        return (f(*shifted(arg, h)) - f(*shifted(arg, -h))) / (2.0 * h)

    def pderiv_f(*arg):
        h = init_h
        old_val = central(arg, h)
        old_diff = float('inf')
        for _ in range(max_iter):
            h /= 2.0
            val = central(arg, h)
            diff = abs(val - old_val)
            if diff < tol:
                return val
            if diff > old_diff:
                # Round-off error now dominates truncation error; further
                # halving would only make the estimate worse.
                return old_val
            old_val, old_diff = val, diff
        return old_val

    return pderiv_f
def grad(f):
    """Return a function estimating the gradient of ``f`` numerically.

    The returned ``grad_f(*arg)`` evaluates ``pderiv(f, i)`` at the point
    ``arg`` for every argument position ``i`` and returns the partial
    derivatives as a tuple, in argument order.  A call with no arguments
    returns an empty tuple.
    """
    def grad_f(*arg):
        # ``range`` (not the Python-2-only ``xrange``) keeps this runnable
        # under Python 3 while behaving the same under Python 2.
        return tuple(pderiv(f, i)(*arg) for i in range(len(arg)))
    return grad_f
if __name__ == '__main__':
    # Demo: numerical gradient of f(x, y, z) = x + y**2 + z**3, whose exact
    # gradient is (1, 2*y, 3*z**2), at a few sample points.
    def f(x, y, z):
        return x + y**2 + z**3

    grad_f = grad(f)
    # Each print call receives a single string, so the output is identical
    # under Python 2 (print statement) and Python 3 (print function).
    print("f(x, y, z) = x + y**2 + z**3")
    for point in [(0, 0, 0), (1, 1, 1), (1, 2, 3)]:
        # str((0, 0, 0)) == "(0, 0, 0)", reproducing the original labels.
        print("grad_f%s = %s" % (point, grad_f(*point)))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.