Last active
April 18, 2018 22:59
-
-
Save freedomtowin/c1ec28af771fce457d53dd4107d418a1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import numpy as np
import scipy.optimize as opt
def poly_least_sqs_loss(x,y,w):
    """Return the mean squared error of the power-law model against ``y``.

    The model evaluated is
    ``w[0]*x[:, 0] + w[1]*x[:, 1] + w[2]*x[:, 1]**w[3]``.

    Parameters
    ----------
    x : 2-D numpy array
        Column 0 is multiplied by ``w[0]``; column 1 is the feature that
        enters both linearly and raised to the power ``w[3]``.
    y : array-like
        Targets, one per row of ``x``.  Both flat ``(n,)`` and column
        ``(n, 1)`` layouts are accepted.
    w : array-like
        The four coefficients.  The array is flattened first, so ``(4,)``,
        ``(1, 4)`` and ``(4, 1)`` are all accepted.

    Returns
    -------
    The sum of squared residuals divided by the number of samples.
    """
    w = np.ravel(w)
    feature = x[:,1:2]
    hypothesis = w[0]*x[:,0:1] + feature*w[1] + w[2]*feature**w[3]
    # ``hypothesis`` is (n, 1).  Subtracting a flat (n,) target from it would
    # broadcast to an (n, n) matrix of all pairwise differences, so both
    # sides are flattened to get one residual per sample.
    residual = np.ravel(hypothesis) - np.ravel(y)
    return np.sum(residual**2)/len(residual)
def poly_function(x, *args_):
    """Evaluate ``c0*x[:, 0] + c1*x[:, 1] + c2*x[:, 1]**c3`` row by row.

    The four coefficients may be passed either as separate positional
    arguments or as one length-4 array; both forms are packed into an array
    and transposed before being indexed.

    Returns an array with one row per row of ``x`` (a single column, because
    the inputs are taken as column slices).
    """
    coeffs = np.transpose(np.array(args_))
    bias_col = x[:, 0:1]
    feature_col = x[:, 1:2]
    linear_part = coeffs[0] * bias_col + feature_col * coeffs[1]
    power_part = coeffs[2] * feature_col ** coeffs[3]
    return linear_part + power_part
def poly_opt_function(*args_):
    """Objective wrapper in the ``f(params, *extra)`` form used by the optimizer.

    Expects exactly three positional arguments -- the coefficient vector, the
    design matrix and the targets -- and forwards them to
    ``poly_least_sqs_loss`` in the argument order that function takes.
    """
    coefficients, design, targets = args_
    return poly_least_sqs_loss(design, targets, coefficients)
# Benchmark several scipy.optimize.minimize solvers on the model
#     price ~ w0 + w1*sqft_living + w2*sqft_living**w3
# `train` and `test` are not defined in this file; they must already exist
# and support column-list indexing with a `.values` attribute.

# The design matrices and targets are the same for every solver, so they are
# built once instead of on every iteration.
X = train[['sqft_living']].values
noisy = np.column_stack((np.ones(X.shape[0]), X))
# Targets stay as (n, 1) columns.  poly_function returns an (n, 1) array, and
# subtracting that from a flattened (n,) vector would broadcast to an (n, n)
# matrix and inflate the sum of squared errors.
y = train[["price"]].values
actual = y

noisy_valid = test[['sqft_living']].values
noisy_valid = np.column_stack((np.ones(noisy_valid.shape[0]), noisy_valid))
actual_valid = test[['price']].values

for method in ["Nelder-Mead","Powell","CG","BFGS","L-BFGS-B","TNC","COBYLA"]:
    start_time = time.time()

    # minimize() expects a 1-D starting point, one entry per coefficient
    # (two design-matrix columns plus the extra scale and exponent terms).
    initial_guess = np.random.uniform(0, 2, noisy.shape[1] + 2)
    # `tol` is the solver-independent tolerance argument; an "ftol" entry in
    # `options` is only recognised by some of the methods listed above.
    result = opt.minimize(poly_opt_function, initial_guess,
                          args=(noisy, actual), method=method,
                          tol=1E-4)
    estimated_params = result.x

    y_pred = poly_function(noisy, estimated_params)
    SST = np.sum((actual - np.mean(actual))**2)
    SSE = np.sum((actual - y_pred)**2)
    end_time = time.time()

    print("method:",method)
    print("time:",end_time-start_time)
    print("loss:",SSE)
    print("estimated parameters:",estimated_params)
    print("train rsquared:",1-SSE/SST)

    # Score the same fitted coefficients on the held-out data.
    y_pred_valid = poly_function(noisy_valid, estimated_params)
    SST = np.sum((actual_valid - np.mean(actual_valid))**2)
    SSE = np.sum((actual_valid - y_pred_valid)**2)
    print("validation rsquared:",1-SSE/SST)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment