(* Source: GitHub gist ajasja/2627215, created May 7, 2012 10:51.
   Defines NelderMeadMinimize`Dump`CompiledNelderMead. *)
(* Produces compiled code for the Nelder-Mead algorithm with the objective function inlined.
   The objective function takes the form F[parametersToOptimize.., constantParameters..]:
   vars names the parameters that are optimized and const names the parameters that are
   held fixed during a run.

   The result is a CompiledFunction cf[pts, cst, tol, maxit] where
     pts   - the initial simplex: a real matrix with one vertex per row;
     cst   - a real vector of values for the constant parameters; it is appended to every
             vertex before the objective is evaluated;
     tol   - convergence tolerance, compared against diffs[history], where history holds the
             best function value recorded at each of the last historyLength iterations
             (initially filled with $MaxMachineNumber);
     maxit - iteration limit; the loop stops once this counter has been decremented to 0.
   cf returns the best vertex found followed by the operation counts
     {evaluations, reflections, expansions, contractions, shrinkages}
   joined into a single list.

   The compiled function is memoized per (objectiveFunction, vars, const, opts), so each
   distinct combination is compiled only once.

   NOTE(review): apply, cumulativeAbsoluteDifferences and the defaults of the options read via
   OptionValue are defined outside this definition; their exact semantics are assumed here. *)
NelderMeadMinimize`Dump`CompiledNelderMead[
  objectiveFunction_Function | objectiveFunction_CompiledFunction, vars : {__Symbol}, const : {__Symbol},
  opts : OptionsPattern[NelderMeadMinimize`Dump`CompiledNelderMead]
 ] :=
 (* Memoize under the exact arguments. Each argument, including const, must be a separate
    argument (comma-separated) so that the stored rule matches later calls. *)
 NelderMeadMinimize`Dump`CompiledNelderMead[
   objectiveFunction, vars, const,
   opts
  ] =
  With[{
    (* Inlined option values *)
    historyLength = If[# === Automatic, 10 Length[vars], #] & @ OptionValue["HistoryLength"],
    reflectRatio = OptionValue["ReflectRatio"], expandRatio = OptionValue["ExpandRatio"],
    contractRatio = OptionValue["ContractRatio"], shrinkRatio = OptionValue["ShrinkRatio"],
    (* Other inlined values *)
    origin = ConstantArray[0., Length[vars]],
    infinity = $MaxMachineNumber,
    epsilon = $MachineEpsilon,
    (* Inlined functions. f is always called below with one vector holding a vertex followed
       by the constant parameter values (presumably apply builds such a function from
       objectiveFunction and the symbol list; verify against its definition). *)
    f = apply[objectiveFunction, Evaluate[vars~Join~const]],
    diffs = cumulativeAbsoluteDifferences,
    (* Options to be passed to Compile *)
    compileopts = Sequence @@ If[$VersionNumber >= 8, {
       (* Mathematica 8 and above offer improved behaviour using these options *)
       RuntimeOptions -> {"Speed", "CompareWithTolerance" -> True, "EvaluateSymbolically" -> False},
       CompilationTarget -> OptionValue[CompilationTarget],
       CompilationOptions -> {"ExpressionOptimization" -> True, "InlineCompiledFunctions" -> Automatic, "InlineExternalDefinitions" -> True}
      }, {
       (* Ordering is an external call in Mathematica 7 and so needs type information *)
       {{_Ordering, _Integer, 1}}
      }
     ]
   },
   Compile[{{pts, _Real, 2}, {cst, _Real, 1}, {tol, _Real, 0}, {maxit, _Integer, 0}},
    Block[{
      (* Housekeeping *)
      history = Table[infinity, {historyLength}], iteration = maxit,
      (* Basic quantities: every evaluation appends the constant parameters to the vertex *)
      simplex = pts, vals = f[#~Join~cst] & /@ pts, ordering,
      (* Calculated points and function values *)
      centroid = origin,
      reflectedPoint = origin, reflectedValue = infinity,
      expandedPoint = origin, expandedValue = infinity,
      contractedPoint = origin, contractedValue = infinity,
      (* More readable indices into the (sorted) simplex array *)
      best = 1, worst = -1, rest = Rest@Range@Length[pts],
      (* Operation counts (for debugging purposes) *)
      evaluations = Length[pts],
      reflections = 0, expansions = 0, contractions = 0, shrinkages = 0
     },
     While[
      (* Order simplex points by function value, best first *)
      ordering = Ordering[vals];
      vals = vals[[ordering]]; simplex = simplex[[ordering]];
      (* Decrement and test iterator (post-decrement: tests the value before decrementing) *)
      (iteration--) != 0,
      (* Check for convergence: overwrite the oldest history entry (position 1) with the
         current best value, then rotate it to the end *)
      history[[1]] = vals[[best]]; history = RotateLeft[history];
      (* NOTE(review): the epsilon term scales diffs[history] itself, so this test is
         equivalent to (1 - epsilon) diffs[history] <= tol; confirm the intended scale *)
      If[diffs[history] <= tol + epsilon diffs[history],
       Break[]
      ];
      (* Centroid of all vertices except the worst *)
      centroid = Mean@Most[simplex];
      (* Reflect the worst vertex through the centroid *)
      reflectedPoint = centroid + reflectRatio (centroid - simplex[[worst]]);
      reflectedValue = f[reflectedPoint~Join~cst]; ++evaluations;
      If[vals[[best]] <= reflectedValue < vals[[-2]],
       vals[[worst]] = reflectedValue; simplex[[worst]] = reflectedPoint;
       ++reflections; Continue[]
      ];
      (* Expand: the reflected point beats the best vertex, so try going further *)
      If[reflectedValue < vals[[best]],
       expandedPoint = centroid + expandRatio (reflectedPoint - centroid);
       expandedValue = f[expandedPoint~Join~cst]; ++evaluations;
       If[expandedValue < reflectedValue,
        vals[[worst]] = expandedValue; simplex[[worst]] = expandedPoint;
        ++expansions; Continue[],
        vals[[worst]] = reflectedValue; simplex[[worst]] = reflectedPoint;
        ++reflections; Continue[]
       ];
      ];
      (* Contract *)
      If[reflectedValue < vals[[worst]],
       (* Outside contraction: between the centroid and the reflected point *)
       contractedPoint = centroid + contractRatio (reflectedPoint - centroid);
       contractedValue = f[contractedPoint~Join~cst]; ++evaluations;
       If[contractedValue <= reflectedValue,
        vals[[worst]] = contractedValue; simplex[[worst]] = contractedPoint;
        ++contractions; Continue[]
       ];,
       (* Inside contraction: between the centroid and the worst vertex *)
       contractedPoint = centroid - contractRatio (centroid - simplex[[worst]]);
       contractedValue = f[contractedPoint~Join~cst]; ++evaluations;
       If[contractedValue < vals[[worst]],
        vals[[worst]] = contractedValue; simplex[[worst]] = contractedPoint;
        ++contractions; Continue[]
       ];
      ];
      (* Shrink every vertex except the best one towards the best vertex.
         The Transpose calls are required: Listable arithmetic between a matrix and a vector
         pairs matrix row i with vector element i, whereas here every vertex (row) must be
         combined with the whole best vertex. Transposing makes rows correspond to
         coordinates, which do line up with the elements of simplex[[best]]. *)
      simplex[[rest]] = Transpose[
        simplex[[best]] + shrinkRatio (Transpose[simplex[[rest]]] - simplex[[best]])
       ];
      (* Append the constant parameters exactly as at every other evaluation *)
      vals[[rest]] = f[#~Join~cst] & /@ simplex[[rest]];
      (* One evaluation per shrunk vertex *)
      evaluations += Length[rest];
      ++shrinkages;
     ];
     (* The While test sorts the simplex before every exit (both Break and the iteration
        limit), so the first row is the best vertex; return it followed by the counts *)
     First[simplex]~Join~{evaluations, reflections, expansions, contractions, shrinkages}
    ],
    compileopts
   ]
  ];
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment