Created
January 23, 2025 02:46
-
-
Save budu/b5447e842bbb8996d5e4ffacd1ca4c30 to your computer and use it in GitHub Desktop.
Implementation of statistical regression analysis (OLS and WLS)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env ruby | |
# rubocop:disable all | |
require 'bundler/inline' | |
gemfile do | |
source 'https://rubygems.org' | |
gem 'numo-linalg' | |
gem 'rover-df' | |
end | |
# Fits linear regression models to columns of a data frame using ordinary
# least squares (OLS) and weighted least squares (WLS) with weights that are
# estimated from the residuals of a previous fit.
#
# A column of ones is prepended to the predictors, so the first coefficient
# of every result is the intercept.
#
# Every public fitting method returns a Hash with the keys :model,
# :coefficients, :residuals, :fitted and :r_squared (WLS variants add
# :weights; the iterative variant adds :iterations, :tolerance, :converged).
class RegressionAnalysis
  include Numo::NMath

  # @param df     data frame supporting `df[cols].to_numo` (the script passes
  #               a Rover::DataFrame)
  # @param x_cols names of the predictor columns
  # @param y_col  name of the response column
  def initialize(df, x_cols, y_col)
    @df = df
    @x_cols = x_cols
    @y_col = y_col
    prepare_data
  end

  # Ordinary least squares fit of +y+ on +x+.
  #
  # @param x design matrix (defaults to the prepared predictors + intercept)
  # @param y response vector (defaults to the prepared response column)
  # @return [Hash] :model, :coefficients, :residuals, :fitted, :r_squared
  def ols(x = @x_data, y = @y_data)
    coefficients = solve(x, y)   # β = (X'X)^(-1) X'y
    fitted = x.dot(coefficients) # ŷ = Xβ
    residuals = y - fitted       # r = y − ŷ

    tss = ((y - y.mean)**2).sum
    rss = (residuals**2).sum

    {
      model: 'OLS',
      coefficients: coefficients,
      residuals: residuals,
      fitted: fitted,
      r_squared: 1 - (rss / tss)
    }
  end

  # Single-pass WLS: weights are estimated once from the OLS residuals.
  #
  # @return [Hash] same keys as #ols plus :weights
  def wls
    wls_fit(weights_ols(ols[:residuals]))
  end

  # Repeatedly re-estimates the weights from the latest residuals and refits
  # WLS until the largest absolute coefficient change drops below +tolerance+
  # or +max_iterations+ refits have been performed.
  #
  # @param max_iterations [Integer] upper bound on the number of WLS refits
  # @param tolerance [Float] convergence threshold on the coefficient change
  # @return [Hash] the last fit, with :model set to 'Iterative WLS' and the
  #   extra keys :iterations (number of WLS refits actually performed),
  #   :tolerance and :converged (whether the threshold was reached)
  def iterative_wls(max_iterations: 10, tolerance: 1e-7)
    current_model = ols
    old_coefficients = current_model[:coefficients].dup
    iterations = 0
    converged = false

    max_iterations.times do
      weights = weights_ols(current_model[:residuals])
      current_model = wls_fit(weights)
      # Count the refit before testing convergence so the reported number is
      # the number of fits performed on both the converged and the
      # exhausted path.
      iterations += 1

      new_coefficients = current_model[:coefficients]
      converged = (new_coefficients - old_coefficients).abs.max < tolerance
      break if converged

      old_coefficients = new_coefficients.dup
    end

    current_model.merge(
      model: 'Iterative WLS',
      iterations:,
      tolerance:,
      converged:
    )
  end

  private

  # Extracts the predictor matrix (with intercept column) and the response
  # vector from the data frame.
  def prepare_data
    @x_data = add_intercept(@df[@x_cols].to_numo)
    @y_data = @df[@y_col].to_numo
  end

  # Prepends a column of ones to +x+ so the model includes an intercept.
  def add_intercept(x)
    ones = Numo::DFloat.ones([x.shape[0], 1])
    Numo::DFloat.hstack([ones, x])
  end

  # Estimates per-observation weights 1/σ² by regressing log(r²) on the
  # design matrix and exponentiating the prediction.
  #
  # @param residuals residual vector from a previous fit
  # @return weight vector with one entry per observation
  def weights_ols(residuals)
    # The small constant keeps log() finite for residuals that are exactly 0.
    log_residuals_sq = Numo::NMath.log(residuals**2 + 1e-8)
    var_model = ols(@x_data, log_residuals_sq) # log(r_i^2) = X * gamma
    predicted_log_residuals_sq = @x_data.dot(var_model[:coefficients])
    predicted_sigma_sq = Numo::NMath.exp(predicted_log_residuals_sq)

    # Weights: 1/σ² (the constant guards against division by zero).
    1.0 / (predicted_sigma_sq + 1e-8)
  end

  # Weighted least squares fit for the given observation weights.
  #
  # @param weights weight vector, one entry per row of +x+
  # @param x design matrix (defaults to the prepared predictors + intercept)
  # @param y response vector (defaults to the prepared response column)
  # @return [Hash] :model, :coefficients, :residuals (unweighted), :fitted,
  #   :r_squared (weighted) and :weights
  def wls_fit(weights, x = @x_data, y = @y_data)
    # Scaling rows by sqrt(w) turns WLS into OLS on the scaled data:
    # β = (X'WX)^(-1) X'Wy
    w_sqrt = weights**0.5
    x_weighted = x * w_sqrt.reshape(x.shape[0], 1)
    y_weighted = y * w_sqrt
    coefficients = solve(x_weighted, y_weighted)

    # Fitted values and residuals are reported on the original scale.
    fitted = x.dot(coefficients)
    residuals = y - fitted
    weighted_residuals = residuals * w_sqrt

    # Weighted R-squared, centred on the weighted mean of y.
    weighted_y_mean = (y * weights).sum / weights.sum
    weighted_tss = (weights * ((y - weighted_y_mean)**2)).sum
    weighted_rss = (weighted_residuals**2).sum

    {
      model: 'WLS',
      coefficients: coefficients,
      residuals: residuals,
      fitted: fitted,
      r_squared: 1 - (weighted_rss / weighted_tss),
      weights: weights
    }
  end

  # Solves the normal equations (X'X) β = X'y for β.
  def solve(x, y)
    xt = x.transpose
    Numo::Linalg.solve(xt.dot(x), xt.dot(y))
  end
end
# Default number of observations generated per sample.
N = 2000

# Builds a Numo::DFloat of length +n+, fills it via Numo::DFloat#rand and
# yields it to the block; the block's value is returned.
def sample(n: N)
  draws = Numo::DFloat.new(n).rand
  yield draws
end
# Synthetic house price data.
# Age is generated but does not enter the price formula; price is linear in
# size, and the noise term is scaled by size so its spread grows with size.
house_age = sample { |u| u * 50 }
house_size = sample { |u| u * 4000 + 1000 }
noise = sample { |u| u * house_size / 1000 }
house_price = 100_000 + 200 * house_size + noise

columns = {
  age: house_age.to_a,
  size: house_size.to_a,
  price: house_price.to_a
}
df = Rover::DataFrame.new(columns)

# Regress price on size and age.
analyzer = RegressionAnalysis.new(df, %i[size age], :price)
# Prints a one-line summary of a regression result to stdout.
#
# @param results [Hash] must provide :model, :coefficients (anything that
#   responds to #to_a) and :r_squared; :iterations is printed when present
# @param precision [Integer] number of decimals the coefficients are rounded to
def print_results(results, precision: 2)
  rounded = results[:coefficients].to_a.map { |value| value.round(precision) }
  summary = "#{results[:model]} coefficients: #{rounded}, R²: #{results[:r_squared]}"

  iterations = results[:iterations]
  summary += ", Iterations: #{iterations}" if iterations

  print "#{summary}\n"
end
# Fit each model in turn and print its summary line.
%i[ols wls iterative_wls].each do |model_name|
  print_results analyzer.public_send(model_name)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment