Last active
October 4, 2016 17:48
-
-
Save sschnug/5d30e1787acd9e8bb69bd53368df6939 to your computer and use it in GitHub Desktop.
Comparison of different approaches for SO-question https://goo.gl/MjADHs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import time | |
from scipy.optimize import minimize, nnls | |
from cvxpy import * | |
np.random.seed(1) | |
""" Test-data generator """ | |
def create_test_data(samples, features, noise, loc=10, scale=2.5): | |
m = np.random.normal(loc=loc, scale=scale, size=(samples, features)) | |
x = np.abs(np.random.normal(size=m.shape[0])) | |
y = np.dot(x, m) | |
y += np.random.normal(loc=0, scale=noise) | |
return np.clip(m, 0, np.inf), y | |
""" SLSQP-based approach """ | |
def solve_slsqp(m, y): | |
def loss(x): | |
return np.sum(np.square((np.dot(x, m) - y))) | |
cons = ({'type': 'eq', | |
'fun' : lambda x: np.sum(x) - 1.0}) | |
x0 = np.zeros(m.shape[0]) | |
start = time.time() | |
res = minimize(loss, x0, method='SLSQP', constraints=cons, | |
bounds=[(0, np.inf) for i in range(m.shape[0])], options={'disp': False}) | |
end = time.time() | |
return end-start, res.fun | |
""" General-purpose SOCP """ | |
def solve_socp(x, y): | |
X = Variable(x.shape[0]) | |
constraints = [X >= 0, sum_entries(X) == 1.0] | |
product = x.T * diag(X) | |
diff = sum_entries(product, axis=1) - y | |
problem = Problem(Minimize(norm(diff)), constraints) | |
start = time.time() | |
problem.solve() # only pure solving time is measured! | |
end = time.time() | |
return end-start, problem.value**2 | |
""" Customized NNLS-based approach """ | |
def solve_nnls(x, y): | |
A = np.vstack([x[0] - x[2], x[1] - x[2]]).T | |
start = time.time() | |
(b1, b2), norm = nnls(A, y - x[2]) | |
b3 = 1 - b1 - b2 | |
end = time.time() | |
return end-start, norm**2 | |
""" Benchmark """ | |
N_ITERS = 5 | |
slsqp_results = [] | |
socp_results = [] | |
nnls_results = [] | |
for i in range(N_ITERS): | |
print('it: ', i) | |
x, y = create_test_data(3, 100000, 5) | |
slsqp_results.append(solve_slsqp(x,y)) | |
socp_results.append(solve_socp(x,y)) | |
nnls_results.append(solve_nnls(x,y)) | |
for i in range(N_ITERS): | |
print(slsqp_results[i]) | |
print(socp_results[i]) | |
print(nnls_results[i]) | |
print('avg(slsqp): ', sum(map(lambda x: x[0], slsqp_results))) | |
print('avg(socp): ', sum(map(lambda x: x[0], socp_results))) | |
print('avg(nnls): ', sum(map(lambda x: x[0], nnls_results))) | |
# ('it: ', 0) | |
# ('it: ', 1) | |
# ('it: ', 2) | |
# ('it: ', 3) | |
# ('it: ', 4) | |
# (0.1381521224975586, 10305754.464317337) | |
# (1.1998240947723389, 10305754.463383121) | |
# (0.0032958984375, 10305754.46431742) | |
# (0.2427079677581787, 29139513.869018022) | |
# (1.5273869037628174, 29139513.903798968) | |
# (0.0025339126586914062, 29068672.21712841) | |
# (0.04007601737976074, 50527506.485803545) | |
# (1.9575321674346924, 50527506.08581031) | |
# (0.0026099681854248047, 50527506.48580391) | |
# (0.019935131072998047, 2540007.8313748916) | |
# (1.1613759994506836, 2540007.831291762) | |
# (0.0031058788299560547, 2540007.8313743635) | |
# (0.6307311058044434, 20703057.561957948) | |
# (1.2694158554077148, 21057781.352033652) | |
# (0.0025250911712646484, 21057781.352120947) | |
# ('avg(slsqp): ', 1.0716023445129395) | |
# ('avg(socp): ', 7.115535020828247) | |
# ('avg(nnls): ', 0.014070749282836914) | |
# NOW TRY 250x250 -> socp now much better than slsqp (but somehow huge diffs in objectives) | |
# x, y = create_test_data(250, 250, 5) | |
# ('it: ', 0) | |
# ('it: ', 1) | |
# ('it: ', 2) | |
# ('it: ', 3) | |
# ('it: ', 4) | |
# (7.564117908477783, 823046842.70197964) | |
# (0.9127869606018066, 899213670.9200472) | |
# (5.963587045669556, 900712224.98805499) | |
# (0.8931920528411865, 1023569003.8219413) | |
# (5.797437906265259, 861403740.39903986) | |
# (0.8841960430145264, 949311480.6180087) | |
# (5.881381034851074, 814168929.89387286) | |
# (0.8888638019561768, 874913376.6625063) | |
# (5.721083879470825, 873559470.83809114) | |
# (0.8824841976165771, 1022278960.7623324) | |
# ('avg(slsqp): ', 30.927607774734497) | |
# ('avg(socp): ', 4.461523056030273) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment