Last active
August 29, 2015 14:18
-
-
Save gbuesing/17d6528aacc556bf1de1 to your computer and use it in GitHub Desktop.
Multivariate linear regression in Ruby, adapted from an example in Andrew Ng's Machine Learning course on Coursera.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'narray' # gem install narray | |
# Multivariate linear regression, trained either by batch gradient descent
# (#fit) or by the closed-form normal equation (#fit_normal).
#
# Before fitting, every feature column is shifted by its mean and divided by
# its standard deviation, and a bias column of ones is prepended, so theta[0]
# is the intercept term.
#
# Relies on the NArray gem (NMatrix / NVector). An NMatrix shape is
# [columns, rows] here: shape[0] is the feature count and shape[1] is the
# sample count.
class LinearRegressor
  attr_reader :theta, :mean, :std, :cost_history

  # @param opts [Hash]
  # @option opts [Numeric] :alpha learning rate for gradient descent (default 0.01)
  # @option opts [Integer] :iterations number of gradient-descent steps (default 400)
  # @option opts [Integer, nil] :log_every print the cost every N iterations
  #   (default 10); nil or a non-positive value disables logging
  def initialize(opts = {})
    @alpha = opts[:alpha] || 0.01
    @iterations = opts[:iterations] || 400
    @log_every = opts.fetch(:log_every, 10)
  end

  # Fits theta by batch gradient descent starting from theta = 0.
  #
  # @param x [Array<Array<Numeric>>] one row per sample, one column per feature
  # @param y [Array<Numeric>] target value for each sample
  # @return [LinearRegressor] self. cost_history then holds iterations + 1
  #   costs: the initial cost followed by one per step.
  def fit(x, y)
    x, @mean, @std = preprocess(x)
    y = NVector[*y]
    m = x.shape[1]
    @theta = NVector.float(x.shape[0])
    @cost_history = []

    # fdiv keeps the step size fractional even if alpha was given as an Integer.
    step = @alpha.fdiv(m)

    record_cost(x, y, m, 0)
    1.upto(@iterations) do |i|
      # Simultaneous update of every theta_j:
      #   theta -= alpha/m * X' (X theta - y)
      @theta -= step * x.transpose * ((x * @theta) - y)
      record_cost(x, y, m, i)
    end
    self
  end

  # Predicts targets for new samples.
  #
  # The samples are scaled with the mean/std learned by the last fit.
  #
  # @param x [Array<Array<Numeric>>] one row per sample
  # @return [NVector] one prediction per sample
  def predict(x)
    x = preprocess(x, @mean, @std).first
    x * @theta
  end

  # Solves for theta in closed form: theta = (X'X)^-1 X'y.
  #
  # Does not modify cost_history.
  #
  # @param x [Array<Array<Numeric>>] one row per sample, one column per feature
  # @param y [Array<Numeric>] target value for each sample
  # @return [LinearRegressor] self
  def fit_normal(x, y)
    x, @mean, @std = preprocess(x)
    y = NVector[*y]
    @theta = (x.transpose * x).inverse * (x.transpose * y)
    self
  end

  private

  # Computes the current cost, appends it to cost_history and logs it.
  def record_cost(x, y, m, iteration)
    cost = compute_cost(x, y, m)
    @cost_history << cost
    log(cost, iteration)
  end

  # Squared-error cost J(theta) = 1/(2m) * sum((X theta - y)^2).
  #
  # For an NVector, `v ** 2` is the inner product v.v, i.e. the sum of
  # squared errors.
  def compute_cost(x, y, m)
    errors = (x * @theta) - y
    (1 / (2.0 * m)) * errors**2
  end

  # Prints the cost on every @log_every-th iteration; silent when logging is disabled.
  def log(c, i)
    return unless @log_every&.positive?

    puts "[#{i}] err: #{c.to_f.round(4)}" if (i % @log_every).zero?
  end

  # Scales each feature column to zero mean and unit standard deviation, then
  # prepends a bias column of ones.
  #
  # @param x [Array<Array<Numeric>>, NMatrix] samples
  # @param mean [NArray, nil] per-feature means to reuse; computed from x when nil
  # @param std [NArray, nil] per-feature std devs to reuse; computed from x when nil
  # @return [Array] [scaled NMatrix with bias column, means, std devs]; this
  #   order matches the `x, @mean, @std = preprocess(x)` destructuring above
  def preprocess(x, mean = nil, std = nil)
    x = NMatrix.cast(x)
    x_mean = mean || x.mean(1)
    # dup so a caller-supplied std array is not mutated by the line below.
    x_std = std ? std.dup : x.stddev(1)
    x_std[x_std.eq(0)] = 1.0 # constant features: avoid dividing by 0
    x = NMatrix.ref((NArray.ref(x) - x_mean) / x_std)
    [add_ones_column(x), x_mean, x_std]
  end

  # Returns a new float NMatrix equal to m with a column of ones inserted as column 0.
  def add_ones_column(m)
    out = NMatrix.float(m.shape[0] + 1, m.shape[1])
    out[1..m.shape[0], true] = m
    out[0, true] = 1
    out
  end
end
require 'csv'

# Load ex1data2.txt.
#
# Every column except the last is a feature; the last column is the target.
# Float() raises on malformed values instead of silently turning them into 0.0.
x = []
y = []
CSV.foreach('ex1data2.txt') do |row|
  next if row.empty? # tolerate blank lines

  *features, target = row
  x << features.map { |v| Float(v) }
  y << Float(target)
end

reg = LinearRegressor.new
reg.fit(x, y)
puts 'Theta:'
p reg.theta

samples = [
  [2104, 3],
  [1600, 3],
  [2400, 3]
]
puts 'Predictions:'
p reg.predict(samples)

require 'gnuplot'

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title 'Cost history'
    plot.xlabel 'iteration'
    plot.ylabel 'error'
    plot.terminal 'png'
    plot.output 'cost_history.png'

    # One x value per recorded cost (the initial cost plus one per iteration).
    # This keeps the axis in sync with whatever :iterations was used, instead
    # of a hard-coded 0..400.
    costs = reg.cost_history
    steps = (0...costs.size).to_a
    plot.data << Gnuplot::DataSet.new([steps, costs]) do |ds|
      ds.with = 'lines'
      ds.notitle
    end
  end
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2104,3,399900
1600,3,329900
2400,3,369000
1416,2,232000
3000,4,539900
1985,4,299900
1534,3,314900
1427,3,198999
1380,3,212000
1494,3,242500
1940,4,239999
2000,3,347000
1890,3,329999
4478,5,699900
1268,3,259900
2300,4,449900
1320,2,299900
1236,3,199900
2609,4,499998
3031,4,599000
1767,3,252900
1888,2,255000
1604,3,242900
1962,4,259900
3890,3,573900
1100,3,249900
1458,3,464500
2526,3,469000
2200,3,475000
2637,3,299900
1839,2,349900
1000,1,169900
2040,4,314900
3137,3,579900
1811,4,285900
1437,3,249900
1239,3,229900
2132,4,345000
4215,4,549000
2162,4,287000
1664,2,368500
2238,3,329900
2567,4,314000
1200,3,299000
852,2,179900
1852,4,299900
1203,3,239500
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment