Created October 24, 2020 20:42
""" | |
This just a quick example of how creating derivatives easily allows us to | |
think about mathematics in a very different, computationally focused way. | |
In this example we consider the defintion of e as the value of x in | |
f(t) = x^t | |
Where the derivative of f, f', is equal to f. | |
f(t) = f'(t) | |
By comparing the loss when we look at f(0) we can use | |
JAX and Newton's method to "discover" e in a way that is | |
very similar in spirit to the analytical approaches but | |
allows us to solve this problem computationally. | |
This example should show up in Hacking Statistics with Python | |
https://www.countbayesie.com/blog/2020/9/16/writing-the-next-book-and-i-want-you-involved | |
""" | |
import jax.numpy as np
from jax import grad
# e is defined as the value of x for which f(t) = x^t satisfies f' = f
# start by defining f
def f(x, t):
    return np.power(x, t)
# use JAX to get the derivative of f with respect to t (argument index 1)
d_f_wrt_t = grad(f, argnums=1)
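# Sanity check (an addition, not part of the original gist): for x > 0,
# d/dt x^t = ln(x) * x^t, so d_f_wrt_t(x, 0.0) should return ln(x).
# central_diff_t is a hypothetical helper that approximates the same
# derivative with a central finite difference in plain Python, so the
# closed form can be checked without JAX.
import math
def central_diff_t(x, t, h=1e-6):
    """Approximate d/dt x^t at t with a central difference of step h."""
    return (x ** (t + h) - x ** (t - h)) / (2 * h)
for base in (2.0, 3.0, 4.0):
    # at t = 0 the closed form ln(x) * x^0 reduces to ln(x)
    assert abs(central_diff_t(base, 0.0) - math.log(base)) < 1e-6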
# the loss is just the difference between f and f' at t = 0;
# it is zero exactly when x = e
def loss_f(x):
    return f(x, 0.0) - d_f_wrt_t(x, 0.0)
# we can now use Newton's method to find the root of our loss function,
# using JAX again for the derivative of the loss with respect to x
d_loss_f = grad(loss_f)
guess = 4.0
for _ in range(10):
    guess -= loss_f(guess) / d_loss_f(guess)
# and tada! we found (float32) e!
print(guess)
# DeviceArray(2.718282, dtype=float32)
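# Closed-form view of the Newton loop (an addition, not from the original
# gist): with f(x, t) = x^t the loss is 1 - ln(x) and its derivative with
# respect to x is -1/x. Each Newton step x - loss(x) / loss'(x) therefore
# simplifies to x * (2 - ln(x)). newton_e_closed_form is a hypothetical
# helper that runs this update in float64 using only the standard library,
# starting from the same guess of 4.0.
import math
def newton_e_closed_form(guess=4.0, steps=10):
    """Iterate the simplified Newton update x <- x * (2 - ln(x))."""
    for _ in range(steps):
        guess = guess * (2.0 - math.log(guess))
    return guess
assert abs(newton_e_closed_form() - math.e) < 1e-12
# The root is simple (loss'(e) = -1/e is nonzero), so Newton's method
# converges quadratically near it. That is why 10 iterations are more
# than enough to reach e.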