Ruby implementation of gradient descent linear regression
samples = [
  { :xs => [ 1.0, 0.25], :y => 0.98},
  { :xs => [ 1.0, 0.49], :y => 0.82},
  { :xs => [ 1.0, 0.60], :y => 0.41},
  { :xs => [ 1.0, 0.89], :y => 0.31}
]
# line is the dot product of the weight (thetas) and input (xs) vectors.
def line( thetas, xs)
  thetas.zip( xs).map { |t, x| t * x}.inject( :+)
end
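# quick sanity check (not part of the original gist): with hypothetical
# thetas [1.0, 2.0] and the first sample's inputs [1.0, 0.25], the dot
# product is 1.0*1.0 + 2.0*0.25 = 1.5.
# puts line( [1.0, 2.0], [1.0, 0.25])   # prints 1.5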
# the error for a function, f, is the difference between the
# expected value, y, and f with a certain parameterization, thetas,
# applied to the inputs, xs.
#
# f needs to be a symbol, :function.
def error( f, thetas, in_and_out)
  xs, y = in_and_out.values_at( :xs, :y)
  y_hat = method( f).call( thetas, xs)
  return y_hat - y
end
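# example (not in the original gist): with thetas [1.0, 1.0] the first sample
# predicts 1.0*1.0 + 1.0*0.25 = 1.25 against an expected y of 0.98, so the
# error is roughly 1.25 - 0.98 = 0.27.
# puts error( :line, [1.0, 1.0], samples.first)   # prints roughly 0.27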
# because we want an overall sense of error, we need to sum up
# the errors for all samples for a given parameterization; however,
# simply adding errors would lead to no error if the first were
# -10 and the second were 10.
#
# Therefore, we square the error. Additionally, we take the average
# squared error (because the sample size is fixed, this doesn't affect
# the outcome). Finally, we take 1/2 of the value (because it makes the
# derivative nicer). Because this is a constant, it doesn't affect the
# outcome either.
def squared_error( f, thetas, data)
  data.map { |datum| error( f, thetas, datum) ** 2}.inject( :+)
end

def mean_squared_error( f, thetas, data)
  count = data.length
  return 0.5 * (1.0 / count) * squared_error( f, thetas, data)
end
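# worked example (not in the original gist): with thetas [0.0, 0.0] every
# prediction is 0, so each error is just -y. The squared errors sum to
# 0.98**2 + 0.82**2 + 0.41**2 + 0.31**2 = 1.897, and the mean squared error
# (with the 1/2 factor) is 0.5 * (1.0 / 4) * 1.897 = 0.237125.
# puts mean_squared_error( :line, [0.0, 0.0], samples)   # prints roughly 0.237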
# we want to generate a grid of potential parameter
# values and then plot out the MSE for each
# set of values.
def plot_mse_for_thetas( step, samples)
  range = Enumerator.new do |g|
    start = -3.0
    stop = 3.0
    current = start
    while current <= stop do
      g.yield current
      current += step
    end
  end
  domain = range.to_a
puts "\t#{domain.join( "\t")}" | |
domain.each do |t0| | |
print "#{t0}" | |
domain.each do |t1| | |
mse = mean_squared_error( :line, [t0, t1], samples) | |
print "\t#{ '%.3f' % mse}" | |
end | |
puts "" | |
end | |
end | |
plot_mse_for_thetas( 0.40, samples)
# view LaTeX here: http://www.codecogs.com/latex/eqneditor.php
# according to Andrew Ng's notes, in gradient descent, each theta should be updated
# by the following rule:
# \theta_j := \theta_j - \alpha \frac{\partial}{\partial \theta_j}MSE(\theta)
# where every $\theta_j$ should be updated simultaneously.
# the derivative of MSE with respect to \theta_j is:
# \frac{1}{m} \sum_i (f(x_i) - y_i)x_{i,j}
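# filling in the intermediate step (not spelled out above): applying the chain
# rule to MSE(\theta) = \frac{1}{2m} \sum_i (f(x_i) - y_i)^2 gives
# \frac{\partial}{\partial \theta_j}MSE(\theta) = \frac{1}{m} \sum_i (f(x_i) - y_i) \frac{\partial f(x_i)}{\partial \theta_j}
# and because f(x_i) = \sum_j \theta_j x_{i,j}, the partial of f with respect
# to \theta_j is just x_{i,j}, which yields the expression above.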
def calculate_gradient_mse( f, thetas, samples)
  averager = 1.0 / samples.length
  gradients = []
  thetas.each_with_index do |theta, index|
    accum = 0.0
    samples.each do |sample|
      xs = sample[ :xs]
      accum += error( f, thetas, sample) * xs[ index]
    end
    gradients << averager * accum
  end
  gradients
end
puts calculate_gradient_mse( :line, [-3.0, -3.0], samples).inspect
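# hand-checkable case (not in the original gist): at thetas [0.0, 0.0] every
# error is -y_i, so the first gradient component is the mean error,
# (-0.98 - 0.82 - 0.41 - 0.31) / 4 = -0.63, and the second weights each error
# by its x_{i,1} before averaging.
# puts calculate_gradient_mse( :line, [0.0, 0.0], samples).inspect   # roughly [-0.63, -0.292]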
def gradient_descent( f, samples, thetas, alpha)
  mse = mean_squared_error( f, thetas, samples)
  diff = 1.0
  while diff.abs > 1.0e-12
    puts "current MSE is #{mse} with thetas #{thetas.inspect}"
    gradients = calculate_gradient_mse( f, thetas, samples)
    changes = gradients.map {|g| - alpha * g}
    thetas = thetas.zip( changes).map {|a, b| a + b}
    prev, mse = mse, mean_squared_error( f, thetas, samples)
    diff = prev - mse
  end
  thetas
end
# the true global minimum (from the normal equations) is near [1.25, -1.12];
# on the coarse grid printed above, the lowest cell is around [1.4, -1.4].
# puts gradient_descent( :line, samples, [1.4, -1.8], 0.01).inspect()
# puts gradient_descent( :line, samples, [-3.0, 3.0], 0.01).inspect() # doesn't find global.
# puts gradient_descent( :line, samples, [3.0, -3.0], 0.01).inspect() # does find global.
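# a small usage sketch (not in the original gist): fit from the origin with the
# same alpha as the runs above, then compare predictions against the observed
# y values. note that gradient_descent prints the MSE at every iteration.
# thetas = gradient_descent( :line, samples, [0.0, 0.0], 0.01)
# samples.each do |s|
#   puts "y = #{s[:y]}, y_hat = #{ '%.3f' % line( thetas, s[:xs])}"
# end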