Last active
December 12, 2022 21:37
-
-
Save sharmaeklavya2/10d6821e6c3ef95f079ab177c63ea029 to your computer and use it in GitHub Desktop.
Linear regression on 1 variable with confidence intervals
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Takes a CSV file as input and performs linear regression on the data. | |
""" | |
import sys | |
import ast | |
import argparse | |
import numpy as np | |
def computeRegParams(x, y):
    """Fit the least-squares line y = intercept + slope * x and its standard errors.

    Args:
        x: 1-D array-like of predictor values.
        y: 1-D array-like of response values, the same length as ``x``.

    Returns:
        dict with keys:
          'n'                  -- number of observations
          'xbar', 'ybar'       -- sample means of x and y
          'slope', 'intercept' -- coefficients of the fitted line
          'sse'                -- sum of squared residuals (never negative)
          'slopeStd'           -- standard error of the slope,
                                  sqrt(sse / (n - 2) / sxx)
          'ybarStd'            -- standard error of the fitted value at x = xbar
                                  (i.e. of ybar), sqrt(sse / (n * (n - 2)))
          'sxx', 'sxy', 'syy'  -- centred sums of squares / cross-products
        Both standard errors are NaN when n == 2, because the line then fits
        exactly and there are no residual degrees of freedom.

    Raises:
        ValueError: if x or y is not 1-D, their lengths differ, fewer than two
            points are given, or all x values are identical (slope undefined).
    """
    # Converting to float accepts lists/int arrays and avoids integer overflow
    # when squaring large integer inputs.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or y.ndim != 1:
        raise ValueError('x and y must be 1-D arrays')
    n = x.shape[0]
    if y.shape[0] != n:
        raise ValueError('x and y must have the same length ({} != {})'.format(n, y.shape[0]))
    if n < 2:
        raise ValueError('need at least 2 points to fit a line, got {}'.format(n))
    if x.max() == x.min():
        raise ValueError('all x values are identical; the slope is undefined')

    xbar, ybar = np.mean(x), np.mean(y)
    xnorm, ynorm = x - xbar, y - ybar
    sxx = np.sum(np.square(xnorm))
    sxy = np.sum(xnorm * ynorm)
    syy = np.sum(np.square(ynorm))
    slope = sxy / sxx
    # Sum the residuals directly instead of using syy - sxy**2 / sxx. The two
    # are algebraically equal, but the subtraction can cancel to a tiny
    # negative number for near-perfect fits, which would make sqrt() NaN.
    sse = np.sum(np.square(ynorm - slope * xnorm))

    dof = n - 2  # residual degrees of freedom
    if dof > 0:
        slopeStd = np.sqrt(sse / dof / sxx)
        ybarStd = np.sqrt(sse / (n * dof))
    else:
        # Residual variance is undefined; report NaN rather than divide by zero.
        slopeStd = ybarStd = np.nan

    return {'n': n, 'xbar': xbar, 'ybar': ybar, 'slope': slope, 'intercept': ybar - slope * xbar,
            'sse': sse, 'slopeStd': slopeStd, 'ybarStd': ybarStd,
            'sxx': sxx, 'sxy': sxy, 'syy': syy,
            }
def _parse_use_cols(text):
    """argparse ``type=``: parse a literal such as '(0, 1)' or '0,1' into two column indices."""
    try:
        cols = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        raise argparse.ArgumentTypeError('not a Python literal: {!r}'.format(text)) from None
    # type(c) is int rejects bools, which isinstance(c, int) would accept.
    if not (isinstance(cols, (tuple, list)) and len(cols) == 2
            and all(type(c) is int for c in cols)):
        raise argparse.ArgumentTypeError(
            'expected a pair of integer column indices, e.g. "0,1"; got {!r}'.format(text))
    return tuple(cols)


def _parse_args():
    """Build the command-line parser, parse sys.argv and validate option values."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('fpath', help='path to CSV file')
    parser.add_argument('--no-plot', dest='plot', action='store_false', default=True)
    parser.add_argument('--confidence', type=float, default=0.95)
    parser.add_argument('--ci-plot-points', type=int, default=100)
    parser.add_argument('--delimiter', default=',')
    parser.add_argument('--skip-rows', type=int, default=0)
    parser.add_argument('--use-cols', type=_parse_use_cols, default=(0, 1))
    parser.add_argument('--max-rows', type=int)
    args = parser.parse_args()
    # A confidence of 0 disables the intervals; 1 or more has no finite interval.
    if not 0 <= args.confidence < 1:
        parser.error('--confidence must be in [0, 1) (0 disables confidence intervals)')
    if args.ci_plot_points < 0:
        parser.error('--ci-plot-points must be non-negative')
    return args


def _t_interval(n, confidence):
    """Return Student-t (lo, hi) quantiles with n-2 dof, or None when intervals are unavailable."""
    try:
        from scipy import stats
    except ImportError:
        print('scipy.stats not found', file=sys.stderr)
        return None
    if not confidence:
        return None
    if n <= 2:
        print('need at least 3 points for confidence intervals', file=sys.stderr)
        return None
    return stats.t(df=n - 2).interval(confidence)


def _print_intervals(params, lo, hi):
    """Print the slope and ybar confidence intervals scaled by their standard errors."""
    for label, centre, std in (('slope', params['slope'], params['slopeStd']),
                               ('ybar', params['ybar'], params['ybarStd'])):
        print('{} confidence interval: [{}, {}] = {} ± {}'.format(
            label, centre + lo * std, centre + hi * std, centre, hi * std))


def _plot(x, y, params, interval, ci_plot_points):
    """Scatter the data, draw the fitted line and, when interval is given, its confidence band."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print('matplotlib.pyplot not found', file=sys.stderr)
        return
    try:
        import seaborn as sns
    except ImportError:
        pass  # seaborn only changes the look of the plot; it is optional
    else:
        sns.set()

    n, xbar, ybar, slope = params['n'], params['xbar'], params['ybar'], params['slope']
    xends = np.array([np.min(x), np.max(x)])
    plt.plot(x, y, 'b.')
    plt.plot(xends, ybar + slope * (xends - xbar), 'r-')
    if interval is not None and ci_plot_points:
        lo, hi = interval
        plotx = np.linspace(xends[0], xends[1], ci_plot_points)
        # Standard error of the fitted line's value at each plotx:
        # sqrt((1/n + (x - xbar)^2 / sxx) * sse / (n - 2)).
        plotstd = np.sqrt((1 / n + np.square(plotx - xbar) / params['sxx'])
                          * params['sse'] / (n - 2))
        ploty = ybar + slope * (plotx - xbar)
        plt.fill_between(plotx, ploty + lo * plotstd, ploty + hi * plotstd,
                         facecolor='#ff808040')
    plt.show()


def main():
    """Command-line entry point: fit a line to two CSV columns, print the fit, optionally plot it."""
    args = _parse_args()
    # ndmin=2 keeps a single-row file from collapsing to 0-d scalars on unpack.
    x, y = np.loadtxt(args.fpath, skiprows=args.skip_rows, usecols=args.use_cols,
                      delimiter=args.delimiter, max_rows=args.max_rows,
                      unpack=True, ndmin=2)
    try:
        params = computeRegParams(x, y)
    except ValueError as err:
        sys.exit('error: {}'.format(err))
    for key, value in params.items():
        print('{}:\t{}'.format(key, value))

    interval = _t_interval(params['n'], args.confidence)
    if interval is not None:
        _print_intervals(params, *interval)
    if args.plot:
        _plot(x, y, params, interval, args.ci_plot_points)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment