Created February 10, 2017 01:45
Save ryanpeach/a6e989c2f36bee9c926bd1a8cb0c190e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random as rand | |
from functools import partial | |
def twiddle(run, args, p, dp, tol=0.2, N=100, logger=None):
    """Minimise ``run`` with the coordinate-wise "twiddle" search.

    This is not gradient descent. Each coordinate of ``p`` (in random order
    per pass) is first nudged by ``+dp[i]``, then by ``-dp[i]``. On an
    improvement that step size grows by 10%. On failure the coordinate is
    restored and its step shrinks by 10%.

    Parameters
    ----------
    run : callable
        Called as ``run(tuple(p), *args)``; returns an error, lower is
        better (0 being optimal).
    args : sequence
        Extra positional arguments forwarded to ``run``.
    p : sequence of float
        Starting parameter vector. It is copied; the caller's object is not
        modified.
    dp : sequence of float
        Initial step magnitude for each index of ``p``. Also copied.
    tol : float
        The search stops once the sum of absolute step sizes is <= ``tol``.
        (This is a threshold on step size, not on the error.)
    N : int or None
        Maximum number of passes over all coordinates; ``None`` means no
        limit.
    logger : logging.Logger or None
        Optional logger that receives ``debug`` progress messages.

    Returns
    -------
    tuple
        ``(best_p, best_dp, best_err)``. ``best_p`` and ``best_dp`` are
        independent list copies: the best parameter vector, and the step
        sizes in effect when it was found. ``best_err`` is
        ``run(tuple(best_p), *args)`` for that vector.
    """
    # Work on private copies so the caller's lists are never mutated.
    p = list(p)
    dp = list(dp)
    # Evaluate the starting point so best_p is never None and the result is
    # never worse than where the search began.
    best_err = run(tuple(p), *args)
    best_p, best_dp = list(p), list(dp)
    n = 0
    # Sum of magnitudes, so mixed-sign steps cannot cancel each other out.
    while sum(abs(d) for d in dp) > tol:
        # N is the maximum number of passes; n counts completed passes.
        if N is not None and n >= N:
            break
        index = list(range(len(p)))
        rand.shuffle(index)
        for i in index:
            p[i] += dp[i]
            err = run(tuple(p), *args)
            # `not err < best_err` (rather than `>=`) also routes NaN errors
            # into the downward probe and, if needed, the restore branch.
            if not err < best_err:
                # Stepping up did not help; try stepping down instead.
                p[i] -= 2 * dp[i]
                err = run(tuple(p), *args)
            if err < best_err:
                best_err = err
                # Snapshot copies; aliasing p/dp would let later probes
                # overwrite the recorded best.
                best_p, best_dp = list(p), list(dp)
                dp[i] *= 1.1
                if logger is not None:
                    logger.debug("P: {0},\nDP: {1},\nRUN(P): {2}".format(p, dp, err))
                    logger.debug("Best P: {0},\nBest Error: {1}\n".format(best_p, best_err))
            else:
                # Neither direction improved: restore p[i] and shrink its step.
                p[i] += dp[i]
                dp[i] *= 0.9
                if logger is not None:
                    logger.debug("Unsuccessful, Error: {0}".format(err))
        n += 1
    if logger is not None:
        logger.debug("Best P: {0},\tBest Error: {1}".format(best_p, best_err))
    return best_p, best_dp, best_err
def randomstart(run, runtime, ranges, dranges, tol=.02, N=100, logger=None, args=()):
    """Run ``twiddle`` repeatedly from random starting points and keep the best.

    Parameters
    ----------
    run : callable
        Objective, called as ``run(tuple(p), *args)``; lower is better.
    runtime : int
        Number of random restarts. A negative value restarts forever.
    ranges : sequence of (low, high)
        Per-coordinate bounds for the random starting ``p``.
    dranges : sequence of (low, high)
        Per-coordinate bounds for the random initial step sizes ``dp``.
    tol, N, logger
        Forwarded to ``twiddle`` for every restart.
    args : sequence, optional
        Extra positional arguments forwarded to ``run``. Defaults to none.

    Returns
    -------
    tuple
        ``(best_err, best_p)`` over all restarts. The result is
        ``(inf, None)`` if no restart ran.
    """
    best_err, best_p, n = float("inf"), None, 0
    while n < runtime or runtime < 0:
        p = [rand.uniform(a, b) for a, b in ranges]
        dp = [rand.uniform(a, b) for a, b in dranges]
        if logger is not None:
            logger.info("Starting P: {0}, \tDP: {1}".format(p, dp))
        # twiddle's signature is (run, args, p, dp, tol, N, logger), and it
        # returns three values (best_p, best_dp, best_err).
        p, _, err = twiddle(run, args, p, dp, tol, N, logger)
        if err < best_err:
            best_err = err
            best_p = p
            if logger is not None:
                logger.critical("Best P: {0},\tBest Error: {1}\n\n".format(best_p, best_err))
        elif logger is not None:
            logger.info("Error: {0}".format(err))
        # Always advance the counter, otherwise a finite runtime never ends.
        n += 1
    return best_err, best_p
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.