Created
September 9, 2019 15:18
-
-
Save Nikolaj-K/1d5ea0302e18a72193f17880cb6a2176 to your computer and use it in GitHub Desktop.
The gradient descent routine from the Wikipedia page, but with our VM.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implementation discussed here:
https://youtu.be/mkQ1G6OAuMA

class VirtualMachine:
    def __init__(self, update_step, max_num_updates=0, halting_condition=lambda state: False):
        self.update_step = update_step
        self.max_num_updates = max_num_updates
        self.halting_condition = halting_condition
        self.num_updates = 0

    def run(self, state):
        halt = False
        while(not halt):
            state = self.update_step(state)
            self.num_updates += 1
            halt = self.halting_condition(state) or self.num_updates == self.max_num_updates
        return state
""" | |
import time | |
from fractions import Fraction | |
from virtual_machine import VirtualMachine, approx_rational | |
# Halting tolerance: the run stops once a single update moves g by less than
# this. Built from an integer numerator/denominator so the value is exactly
# 1/10**5; Fraction(10**-5) would instead capture the rounding error of the
# binary float 1e-05 (a fraction with a ~2**59 denominator).
PREC = Fraction(1, 10**5)
# Multiplier applied to the gradient in each update step, exactly 1/100
# (again avoiding the inexact float 0.01 and its enormous denominator, which
# would otherwise propagate into every rational iterate).
STEP_SIZE = Fraction(1, 10**2)
def df(x):
    """
    Derivative of the objective f(x) := x**4 - 3 x**3 + 2.

    Objective taken from https://en.wikipedia.org/wiki/Gradient_descent#Python
    Plot of f on [-2, 4]:
    https://www.wolframalpha.com/input/?i=Plot%5Bx%5E4+-+3+*+x%5E3+%2B+2,+%7Bx,+-2,+4%7D%5D

    df(x) = 4 x**3 - 9 x**2 = x**2 (4 x - 9), which vanishes at x = 0 and x = 9/4.
    """
    cubic_term = 4 * x**3
    quadratic_term = 9 * x**2
    return cubic_term - quadratic_term
def take_step(x):
    """
    Perform one gradient-descent update starting from x.

    Moves x against the gradient df(x), scaled by STEP_SIZE.
    """
    gradient = df(x)
    return x - STEP_SIZE * gradient
def update_step(state):
    """
    Advance the gradient-descent state by one update (the VM's update_step).

    Prints the current iterate g and its change since the previous iterate,
    pauses for 0.04 s, then stores the current g as 'prev_g' and replaces 'g'
    with approx_rational(take_step(g)).

    state: dict with keys 'g' (current iterate) and 'prev_g' (previous iterate).
    The dict is mutated in place and also returned.
    """
    current = state['g']
    change = current - state['prev_g']
    print('g={} (={}), delta = {}'.format(current, float(current), float(change)))
    # Presumably slows the printed trace down so it can be watched; affects
    # only pacing, not the computed iterates.
    time.sleep(.04)
    state['prev_g'] = current
    state['g'] = approx_rational(take_step(current))
    return state
def halting_condition(state):
    """
    Return True once the most recent update moved g by less than PREC.

    state: dict with keys 'g' (current iterate) and 'prev_g' (previous iterate).
    """
    step_length = abs(state['g'] - state['prev_g'])
    return step_length < PREC
if __name__ == "__main__":
    # Upper bound on the number of updates. With the default max_num_updates=0
    # the VM can only stop via halting_condition (num_updates is incremented
    # before it is compared, so it never equals 0), and a non-converging run
    # would loop forever.
    vm_gd = VirtualMachine(update_step, max_num_updates=10000, halting_condition=halting_condition)
    # Do not start at x = 0: df(0) == 0, so the first update would leave g
    # unchanged, halting_condition would fire immediately, and no descent
    # would happen. From x = 6 (where df > 0) the iterates move left and then
    # approach the local minimum at x = 9/4.
    # prev_g only needs to be far from g so the first delta is large.
    state = dict(prev_g=-9001, g=6)  # initial guesses
    final_state = vm_gd.run(state)
    print('final g = {} (={})'.format(final_state['g'], float(final_state['g'])))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment