Last active
May 2, 2023 16:03
-
-
Save rygrob/7b23eb1ab13fb6f2c2993f1a10cd02a3 to your computer and use it in GitHub Desktop.
CVX Log2 Approximation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
import cvxpy as cp | |
def get_cvx(fn, deg, lo, hi, rel=True, m=500):
    """Fit a minimax polynomial in (x - 1), with no constant term, to ``fn``.

    The model is

        p(x) = c[0]*(x - 1)**deg + c[1]*(x - 1)**(deg - 1) + ... + c[deg-1]*(x - 1)

    so p(1) == 0. The coefficients minimize the maximum error over ``m``
    evenly spaced samples of [lo, hi]. This is solved as a convex problem
    with cvxpy.

    Args:
        fn: vectorized function to approximate (e.g. ``np.log2``).
        deg: number of polynomial terms / highest power of (x - 1); >= 1.
        lo, hi: sampling interval (endpoints included, via np.linspace).
        rel: if True, minimize max |p(x) - fn(x)| / |fn(x)|; otherwise
            minimize max |p(x) - fn(x)|.
        m: number of sample points.

    Returns:
        List of ``deg`` floats, highest power first. For ``deg == 2`` this
        is ``[A, B]`` with p(x) = A*(x - 1)**2 + B*(x - 1).

    Raises:
        ValueError: if ``deg < 1`` or no usable sample points remain.
        RuntimeError: if the solver does not report an optimal solution.
    """
    if deg < 1:
        raise ValueError(f"deg must be >= 1, got {deg}")

    x = np.linspace(lo, hi, m)
    y = fn(x)

    if rel:
        # Relative error is undefined where fn(x) == 0 (e.g. x == 1 for
        # log2); drop those samples instead of passing inf to the solver.
        keep = y != 0
        x, y = x[keep], y[keep]
    if x.size == 0:
        raise ValueError("no sample points left to fit")

    # Design matrix: columns are (x - 1)**deg, (x - 1)**(deg - 1), ..., (x - 1).
    t = x - 1.0
    design = np.column_stack([t**p for p in range(deg, 0, -1)])

    coeffs = cp.Variable(deg)
    residual = design @ coeffs - y
    if rel:
        residual = cp.multiply(residual, 1.0 / y)

    # Minimax (L-infinity) objective: minimize the worst-case error.
    problem = cp.Problem(cp.Minimize(cp.max(cp.abs(residual))))
    problem.solve()
    if problem.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
        raise RuntimeError(f"cvxpy solve failed with status {problem.status!r}")

    return [float(c) for c in coeffs.value]
def model(x, a, b):
    """Evaluate the quadratic a*(x - 1)**2 + b*(x - 1) at ``x``.

    ``x`` may be a scalar or a NumPy array; the result has the same shape.
    """
    shifted = x - 1.0
    return a * shifted**2 + b * shifted
# Interval on which log2 is approximated; shared by the fits and the plot grid
# so the two cannot drift apart.
lo, hi = 0.75, 1.5

# Degree-2 fits in (x - 1): one minimizing max relative error, one max absolute error.
u = get_cvx(np.log2, 2, lo, hi, rel=True)
v = get_cvx(np.log2, 2, lo, hi, rel=False)

# Evaluation grid with 0.01 spacing. Unlike np.arange, linspace includes the
# upper endpoint `hi`, where a minimax error curve typically has an extremum.
x = np.linspace(lo, hi, int(round((hi - lo) / 0.01)) + 1)
# log2(x) -> 0 at x == 1, so relative error is undefined there; skip that point.
x = x[np.abs(x - 1.0) > 0.001]
y = np.log2(x)

# All three curves are plotted as *relative* error, including the
# absolute-error fit, so they are directly comparable.
errf_cvx_mine_rel = (model(x, u[0], u[1]) - y) / y
errf_cvx_mine_abs = (model(x, v[0], v[1]) - y) / y
# Hard-coded reference coefficients, labelled "CVX Ebay" in the plot.
errf_cvx_ebay = (model(x, -0.629673, 1.466967) - y) / y

plt.plot(x, errf_cvx_mine_rel, label="CVX Mine, Rel.")
plt.plot(x, errf_cvx_mine_abs, label="CVX Mine, Abs.")
plt.plot(x, errf_cvx_ebay, label="CVX Ebay")
plt.xlabel("x")
plt.ylabel("rel. err")
plt.legend(loc="upper left")
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment