Skip to content

Instantly share code, notes, and snippets.

@j20232
Created October 12, 2021 05:22
Show Gist options
  • Save j20232/c4c4ed1d6fedb0f27d4ff5284f64e7f6 to your computer and use it in GitHub Desktop.
Save j20232/c4c4ed1d6fedb0f27d4ff5284f64e7f6 to your computer and use it in GitHub Desktop.
taylor_approximation.py
import numpy as np
import argparse
def f(x, func_id):
    """Evaluate one of the two-variable demo functions at ``x``.

    Args:
        x: Indexable with at least two numeric entries; only ``x[0]`` and
            ``x[1]`` are read.
        func_id: Selects the function:
            ``0`` -> ``x0 * x1 + log(x0)``,
            ``1`` -> ``sin(x0) + cos(x1)``.

    Returns:
        The scalar function value.

    Raises:
        NotImplementedError: If ``func_id`` is neither 0 nor 1.
    """
    if func_id == 0:
        bilinear_term = x[0] * x[1]
        log_term = np.log(x[0])
        return bilinear_term + log_term
    if func_id == 1:
        sine_term = np.sin(x[0])
        cosine_term = np.cos(x[1])
        return sine_term + cosine_term
    raise NotImplementedError("err")
def grad(x, func_id):
    """Return the analytic gradient of ``f`` at ``x``.

    Args:
        x: Indexable with at least two numeric entries; only ``x[0]`` and
            ``x[1]`` are read.
        func_id: Selects the function whose gradient is returned:
            ``0`` -> ``x0 * x1 + log(x0)``,
            ``1`` -> ``sin(x0) + cos(x1)``.

    Returns:
        A length-2 ``numpy`` array ``[df/dx0, df/dx1]``.

    Raises:
        NotImplementedError: If ``func_id`` is neither 0 nor 1.
    """
    if func_id == 0:
        # d/dx0 (x0*x1 + log x0) = x1 + 1/x0
        # d/dx1 (x0*x1 + log x0) = x0   (not x1; the two only coincide
        # on the diagonal x0 == x1, e.g. at the point (1, 1)).
        return np.array([x[1] + 1 / x[0], x[0]])
    elif func_id == 1:
        # d/dx0 sin(x0) = cos(x0);  d/dx1 cos(x1) = -sin(x1)
        return np.array([np.cos(x[0]), -np.sin(x[1])])
    else:
        raise NotImplementedError("err")
def hessian(x, func_id):
    """Return the analytic 2x2 Hessian of ``f`` at ``x``.

    Args:
        x: Indexable with at least two numeric entries; only ``x[0]`` and
            ``x[1]`` are read.
        func_id: Selects the function whose Hessian is returned:
            ``0`` -> ``x0 * x1 + log(x0)``,
            ``1`` -> ``sin(x0) + cos(x1)``.

    Returns:
        A ``(2, 2)`` ``numpy`` array of second partial derivatives.

    Raises:
        NotImplementedError: If ``func_id`` is neither 0 nor 1.
    """
    if func_id == 0:
        # Only the log term contributes curvature along x0; the bilinear
        # term x0*x1 supplies the constant off-diagonal entries.
        d2_dx0 = -1 / (x[0] ** 2)
        top_row = [d2_dx0, 1]
        bottom_row = [1, 0]
        return np.array([top_row, bottom_row])
    if func_id == 1:
        # The function is separable, so the mixed partials vanish.
        d2_dx0 = -np.sin(x[0])
        d2_dx1 = -np.cos(x[1])
        return np.array([[d2_dx0, 0], [0, d2_dx1]])
    raise NotImplementedError("err")
if __name__ == '__main__':
    # Compare 0th-, 1st- and 2nd-order Taylor expansions of f around x1,
    # evaluated at x2, against the true value f(x2).
    cli = argparse.ArgumentParser()
    cli.add_argument("--func_id", "-f", type=int, default=0)
    func_id = int(cli.parse_args().func_id)

    x1 = np.array([1, 1])      # expansion point
    x2 = np.array([1.2, 1.2])  # evaluation point
    step = x2 - x1

    truth_x1 = f(x1, func_id=func_id)
    print("GT f(x1): ", truth_x1)
    truth_x2 = f(x2, func_id=func_id)
    print("GT f(x2): ", truth_x2)

    # Build each order on top of the previous one.
    app_0 = truth_x1
    slope = grad(x1, func_id=func_id)
    app_1 = app_0 + slope @ step
    curvature = hessian(x1, func_id=func_id)
    # Evaluated left to right as ((0.5 * step) @ curvature) @ step.
    app_2 = app_1 + 0.5 * step @ curvature @ step

    # "loss" is the squared error of each approximation against f(x2).
    report = (
        ("0-approx: ", app_0),
        ("1-approx: ", app_1),
        ("2-approx: ", app_2),
    )
    for label, estimate in report:
        print(label, estimate, ", loss: ", (estimate - truth_x2) ** 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment