Last active
January 13, 2019 11:59
-
-
Save ameya98/c485e11ef149d4adaa242099e6265e85 to your computer and use it in GitHub Desktop.
Bayesian Linear Regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
Bayesian Linear Regression in WebPPL | |
Author: Ameya Daigavane | |
Runnable in WebPPL's online editor at webppl.org | |
*/ | |
/* Training Data */ | |
let x_train = _.range(0, 10, 0.1); | |
let y_train = _.range(20, 0, -0.2); | |
let sample_uniform = function(){ | |
return uniform(0, 10); | |
} | |
/* 100 samples in the range 0 to 10. */ | |
let x_test = repeat(100, sample_uniform); | |
let y_test = map(function(x) {return 20 - 2 * x;}, x_test); | |
/* Linear Regression Model. */ | |
let lr = function() { | |
/* Prior beliefs. */ | |
let posterior_m = uniform(-10, 10); | |
let posterior_intercept = uniform(0, 100); | |
let err_sigma = uniform(0, 0.1); | |
/* Condition on training data. */ | |
mapData({data: x_train}, function(x_sample, ind){ | |
let y_pred = posterior_m * x_sample + posterior_intercept; | |
observe(Gaussian({mu: y_pred, sigma:err_sigma}), y_train[ind]); | |
}); | |
/* Predictions once conditioned. */ | |
let predictions = mapData({data: x_test}, function(x_sample){ | |
return posterior_m * x_sample + posterior_intercept; | |
}); | |
return { | |
m: posterior_m, | |
intercept: posterior_intercept, | |
sigma: err_sigma, | |
test_predictions: predictions, | |
}; | |
}; | |
/* Joint distribution. */ | |
let dist = Infer({method: 'MCMC', samples: 50000}, lr); | |
/* Marginal distributions. */ | |
let dist_m = marginalize(dist, function(d) { return d.m}); | |
let dist_intercept = marginalize(dist, function(d) { return d.intercept}); | |
let dist_sigma = marginalize(dist, function(d) { return d.sigma}); | |
let dist_preds = marginalize(dist, function(d) { return d.test_predictions}); | |
/* Visualize marginal densities. */ | |
viz.density(dist_m, {bounds: [-5, 0]}); | |
viz.density(dist_intercept, {bounds: [0, 100]}); | |
viz.density(dist_sigma, {bounds: [0, 0.5]}); | |
/* MAP value learned. */ | |
let m_learnt = dist_m.MAP().val; | |
let intercept_learnt = dist_intercept.MAP().val; | |
let y_predictions = dist_preds.MAP().val; | |
display("Slope learnt: " + m_learnt); | |
display("Intercept learnt: " + intercept_learnt); | |
/* Residuals over the test set. */ | |
let test_residuals = mapData({data: y_test}, function(__, ind) { | |
return Math.sqrt(Math.pow((y_test[ind] - y_predictions[ind]), 2)); | |
}); | |
let mean_squared_error = sum(test_residuals) / test_residuals.length; | |
display("Mean squared error on test data = " + mean_squared_error); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment