Created October 30, 2012 06:46

Save iamFIREcracker/3978678 to your computer and use it in GitHub Desktop.
Ternary Search in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module implements the ternary search algorithm:

    https://en.wikipedia.org/wiki/Ternary_search

It exports functions that find, within a given range, the point at which a
unimodal function reaches its maximum (`maximize`) or its minimum
(`minimize`), up to a requested precision.
"""
from __future__ import print_function
from __future__ import division
def maximize(func, a, b, precision):
    """
    Find the point in the range [a, b] which maximizes the input function,
    with an error at most equal to `precision` / 2.

    The function must be unimodal over the range: strictly increasing and
    then strictly decreasing. A monotonic function is the degenerate case
    whose maximum lies at one end of the range. A function that decreases
    and then increases is NOT supported, because ternary search may then
    discard the endpoint holding the true maximum.

    Bounds may be given in either order. If `precision` is finer than the
    floating-point resolution around the optimum, the search stops once the
    interval can no longer shrink, and the midpoint of that tiny interval is
    returned.

    Raises ValueError if `precision` is not a positive number.

    >>> def identity(x):
    ...     return x
    ...
    >>> print("{0:.0f}".format(maximize(identity, 0, 10, 1e-3)))
    10
    >>> print("{0:.0f}".format(maximize(lambda x: -(x - 3) ** 2, 0, 10, 1e-3)))
    3
    """
    # `not precision > 0` also rejects NaN; a non-positive precision would
    # otherwise make the loop below run forever.
    if not precision > 0:
        raise ValueError("precision must be a positive number")
    # Normalize reversed bounds so the width computation below is meaningful.
    left, right = (a, b) if a <= b else (b, a)
    while right - left > precision:
        left_third = (2 * left + right) / 3
        right_third = (left + 2 * right) / 3
        if func(left_third) < func(right_third):
            # The maximum cannot lie in [left, left_third).
            new_left, new_right = left_third, right
        else:
            # The maximum cannot lie in (right_third, right].
            new_left, new_right = left, right_third
        if (new_left, new_right) == (left, right):
            # Floating-point resolution exhausted: the thirds rounded onto the
            # endpoints, so the interval can no longer shrink.
            break
        left, right = new_left, new_right
    return (left + right) / 2
def minimize(func, a, b, precision):
    """
    Find the point in the range [a, b] which minimizes the input function,
    with an error at most equal to `precision` / 2.

    Minimization is implemented by maximizing the negated input function
    with `maximize`. Consequently the function to minimize must be unimodal
    the other way round: strictly decreasing and then strictly increasing
    over the range. A monotonic function is the degenerate case whose
    minimum lies at one end of the range.

    >>> def shifted_parabola(shift):
    ...     def parabola(x):
    ...         return (x + shift) ** 2
    ...     return parabola
    ...
    >>> print("{0:.0f}".format(minimize(shifted_parabola(2), -5, 5, 1e-3)))
    -2
    >>> from math import exp
    >>> print("{0:.0f}".format(minimize(exp, 1, 100, 1e-3)))
    1
    """
    def negated(*args, **kwargs):
        # Flipping the sign turns the minimum of `func` into a maximum.
        return -func(*args, **kwargs)

    return maximize(negated, a, b, precision)
if __name__ == "__main__":
    # Running the module as a script executes the doctest examples embedded
    # in this module's docstrings.
    from doctest import testmod
    testmod()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.