octave 通用线性回归模型计算
Last active
April 11, 2018 05:44
-
-
Save BadUncleX/06c1e1a2db4778a2f688 to your computer and use it in GitHub Desktop.
octave 通用线性回归模型计算 机器学习
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [result] = generalCompute(X, y, alpha, num_iters, targetX)
  % GENERALCOMPUTE Fit a linear regression model by batch gradient descent
  % on z-score-normalized features, plot the fit and the cost curve, and
  % predict y for targetX.
  %
  % Inputs:
  %   X         - training features, one sample per row
  %   y         - training targets, one entry per row of X
  %   alpha     - gradient-descent learning rate
  %   num_iters - number of gradient-descent iterations
  %   targetX   - rows to predict; same number of columns as X
  %               (one or more rows)
  % Output:
  %   result    - predicted values, one per row of targetX (also displayed)

  X_original = X;  % keep the raw features for plotting
  m = size(X, 1);

  % One parameter per feature plus the intercept term.
  theta = zeros(size(X, 2) + 1, 1);

  % Normalize the features. mu and sigma are kept so that targetX can be
  % scaled exactly like the training data before predicting.
  [X, mu, sigma] = featureGeneralNormalize(X);

  % Add the intercept column of ones.
  X = [ones(m, 1) X];

  % Run gradient descent. gradientGeneral takes exactly five inputs;
  % targetX is not needed for training.
  [theta, J_history] = gradientGeneral(X, y, theta, alpha, num_iters);

  % Scatter plot of the raw training data.
  figure;
  plot(X_original, y, 'rx', 'MarkerSize', 10);
  xlabel('x');
  ylabel('y ');

  % Overlay the fitted line. theta was learned on the normalized design
  % matrix, so the fitted values are X * theta (not X_original * theta).
  % They are plotted against the raw x values so the axis stays readable.
  hold on;
  plot(X_original, X * theta, '-b');
  legend('Training data', 'Linear regression')
  hold off

  % Convergence plot: cost after each iteration.
  figure;
  plot(1:numel(J_history), J_history, '-b', 'LineWidth', 2);
  xlabel('Number of iterations');
  ylabel('Cost J');

  % Predict. Scale targetX with the training-set mu/sigma, then prepend
  % one intercept 1 per row. Sizing the ones column by the row count lets
  % several samples be predicted at once.
  targetX = (targetX - mu) ./ sigma;
  targetX = [ones(size(targetX, 1), 1) targetX];
  price = targetX * theta;
  disp('the result is: '), disp(price);
  result = price;
end
% ============================================================ | |
%% Gradient descent
function [theta, J_history] = gradientGeneral(X, y, theta_p, alpha, num_iters)
  % Batch gradient descent for linear regression.
  %
  % Starting from theta_p, performs num_iters updates of the form
  %   theta <- theta - (alpha / m) * X' * (X * theta - y),   m = length(y)
  % and records computeCostGeneral(X, y, theta) after every update.
  %
  % Returns:
  %   theta     - parameters after the last update
  %   J_history - num_iters x 1 vector of the cost after each update
  nSamples = length(y);
  learnStep = alpha / nSamples;   % same scale factor applied on every step
  J_history = zeros(num_iters, 1);
  theta = theta_p;
  for k = 1:num_iters
    residual = X * theta - y;
    theta = theta - learnStep * (X' * residual);
    J_history(k) = computeCostGeneral(X, y, theta);
  end
end
% ============================================================ | |
% Cost function
function J = computeCostGeneral(X, y, theta)
  % Squared-error cost for linear regression:
  %   J = 1/(2m) * sum((X * theta - y) .^ 2),   m = length(y)
  % evaluated in vectorized form as residual' * residual, so the
  % prediction X * theta is computed only once.
  m = length(y);
  residual = X * theta - y;
  J = 1 / (2 * m) * (residual' * residual);
end
% ============================================================ | |
%% Feature scaling
function [X_norm, mu, sigma] = featureGeneralNormalize(X)
  % Z-score normalize every column of X:  X_norm = (X - mu) ./ sigma.
  %
  % Returns:
  %   X_norm - normalized features, same size as X
  %   mu     - row vector of per-column means
  %   sigma  - row vector of per-column standard deviations (std with
  %            N-1 normalization). Zero entries (a constant column, or a
  %            single-row X) are replaced by 1 so the division never
  %            yields NaN/Inf. Callers can reuse mu and sigma to scale new
  %            inputs the same way.
  %
  % mean/std are taken explicitly along dimension 1. Without the dim
  % argument, a single-row X would be reduced along the row, giving a
  % scalar instead of one statistic per column.
  mu = mean(X, 1);
  sigma = std(X, 0, 1);
  sigma(sigma == 0) = 1;
  X_norm = (X - mu) ./ sigma;
end
% ============================================================ | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment