XinyueZ/2657b36ef6b59f2521016fb2c9234fef (GitHub Gist), last active January 2, 2022 16:14
The normal equations
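For reference, the closed form that `NE_Training` evaluates comes from minimizing the squared error. The gist does not state its cost function, so the squared-error cost $J(W)$ below is an assumption, although it is the standard choice for linear regression. Here $X$ is $m \times n$, $W$ is $n \times 1$ and $Y$ is $m \times 1$:

```latex
\begin{aligned}
J(W) &= \lVert XW - Y \rVert_2^2 = (XW - Y)^\top (XW - Y) \\
\nabla_W J(W) &= 2\,X^\top (XW - Y) = 0 \\
\Rightarrow\quad X^\top X\, W &= X^\top Y \\
\Rightarrow\quad W &= (X^\top X)^{-1} X^\top Y
\end{aligned}
```

`pinv(X' * X) * X'` equals the Moore-Penrose pseudoinverse of `X`. As a result, the code returns the minimum-norm least-squares solution even when `X' * X` is singular.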
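disp("Using the normal equations (no iterative training) to get linear regression weights...");
% Return an in-by-out matrix drawn uniformly from [-epsilon_init, epsilon_init]
% (the Glorot/Xavier uniform range); reused below to generate synthetic data.
function W = randWeights(in, out)
  epsilon_init = sqrt(6) / sqrt(in + out);
  W = rand(in, out) * 2 * epsilon_init - epsilon_init;
endfunction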
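% Generate synthetic data with m rows (examples) and n feature columns.
m = 5000;  % number of rows (examples)
n = 100;   % number of input features
X = randWeights(m, n);
% Noise-free linear targets, so the normal equations should fit them up to rounding error.
Y = X * randWeights(n, 1);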
% Split: the first 80% of rows for training, the remaining 20% for validation.
sp = floor(m * 0.8);  % number of training rows
% Training data
X_train = X(1:sp, :);
Y_train = Y(1:sp);
% Cross-validation data (rows sp+1..m, disjoint from the training rows)
X_cv = X(sp + 1:end, :);
Y_cv = Y(sp + 1:end);
% Closed-form least-squares fit via the normal equations.
function W = NE_Training(X, Y)
  W = pinv(X' * X) * X' * Y;
endfunction
W = NE_Training(X_train, Y_train);
% Evaluate the hypothesis h(x) = X * W on the cross-validation data.
hx = X_cv * W;
% Print the relative difference ||Y_cv - hx|| / ||Y_cv + hx|| and mark it
% with a check sign when it is below 1e-9.
function checkHypothesis(Y_cv, hx)
  rel_diff = norm(Y_cv - hx) / norm(Y_cv + hx);
  checkSign = "𐄂";
  if (rel_diff < 1e-9)
    checkSign = "✓";
  endif
  printf("\nRelative difference between Y_cv and hx (should be < 1e-9): %g (%s)\n", rel_diff, checkSign);
endfunction
checkHypothesis(Y_cv, hx);
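This is a minimal sketch, under stated assumptions, that applies the same closed form to a toy 3-by-2 system small enough to check by hand. The names `W_ne` and `W_bs` are hypothetical. The column of ones, which acts as an intercept term, is an assumption of this toy example: the gist's own data has no intercept. As a cross-check that is not part of the original gist, the sketch also calls Octave's backslash operator, which returns a least-squares solution for an overdetermined system.

```octave
% Toy example (hypothetical data, for illustration only).
X = [1 1; 1 2; 1 3];   % first column of ones acts as an intercept term (assumption)
Y = [1; 2; 3];         % Y equals the second column exactly
% Normal equations, written the same way as in NE_Training.
W_ne = pinv(X' * X) * X' * Y;
% Least-squares solution via the backslash operator (cross-check only).
W_bs = X \ Y;
printf("W_ne = [%g; %g]\n", W_ne);
printf("W_bs = [%g; %g]\n", W_bs);
```

By hand, X' * X = [3 6; 6 14] and X' * Y = [6; 14], which gives W = [0; 1] (intercept 0, slope 1). Both printed vectors should match that up to floating-point rounding. The backslash route avoids forming X' * X explicitly. Forming that product squares the condition number of X.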