Skip to content

Instantly share code, notes, and snippets.

@sharmaeklavya2
Last active December 12, 2022 21:37
Show Gist options
  • Save sharmaeklavya2/10d6821e6c3ef95f079ab177c63ea029 to your computer and use it in GitHub Desktop.
Save sharmaeklavya2/10d6821e6c3ef95f079ab177c63ea029 to your computer and use it in GitHub Desktop.
Linear regression on 1 variable with confidence intervals
#!/usr/bin/env python3
"""
Takes a CSV file as input and performs linear regression on the data.
"""
import sys
import ast
import argparse
import numpy as np
def computeRegParams(x, y):
    """Fit the least-squares line ``y = intercept + slope * x`` to paired data.

    Args:
        x: 1-D array-like of predictor values.
        y: 1-D array-like of response values, same length as ``x``.

    Returns:
        A dict with the keys:

        * ``n`` -- number of data points.
        * ``xbar``, ``ybar`` -- sample means of ``x`` and ``y``.
        * ``slope``, ``intercept`` -- coefficients of the fitted line.
        * ``sxx``, ``sxy``, ``syy`` -- centered sums of squares / products.
        * ``sse`` -- residual sum of squares (never negative).
        * ``slopeStd`` -- standard error of the slope,
          ``sqrt(sse / (n - 2) / sxx)``.
        * ``ybarStd`` -- standard error of the fitted value at ``xbar``,
          ``sqrt(sse / (n * (n - 2)))``.

    Raises:
        AssertionError: if ``x`` or ``y`` is not 1-D, or their lengths differ.

    Note:
        The standard errors divide by ``n - 2`` and the slope divides by
        ``sxx``; with fewer than three points or a constant ``x`` the
        affected entries come out as nan/inf rather than raising.
    """
    # Accept plain sequences as well as ndarrays.
    x, y = np.asarray(x), np.asarray(y)
    assert len(x.shape) == 1 and len(y.shape) == 1, 'x and y must be 1-D'
    n = x.shape[0]
    assert y.shape[0] == n, 'x and y must have the same length'

    xbar, ybar = np.mean(x), np.mean(y)
    xnorm, ynorm = x - xbar, y - ybar
    sxx = np.sum(np.square(xnorm))
    sxy = np.sum(xnorm * ynorm)
    syy = np.sum(np.square(ynorm))

    slope = sxy / sxx
    # Mathematically sse >= 0, but for (nearly) collinear data the
    # subtraction can round to a tiny negative value, which would turn both
    # standard errors into nan via sqrt.  np.maximum keeps the numpy scalar
    # type (and still propagates nan when sxx == 0).
    sse = np.maximum(syy - (sxy ** 2 / sxx), 0.0)
    slopeStd = np.sqrt(sse / (n - 2) / sxx)
    ybarStd = np.sqrt(sse / (n * (n - 2)))

    return {
        'n': n, 'xbar': xbar, 'ybar': ybar,
        'slope': slope, 'intercept': ybar - slope * xbar,
        'sse': sse, 'slopeStd': slopeStd, 'ybarStd': ybarStd,
        'sxx': sxx, 'sxy': sxy, 'syy': syy,
    }
def _parse_use_cols(text):
    """Parse the ``--use-cols`` value into a pair of integer column indices."""
    try:
        cols = ast.literal_eval(text)
    except (ValueError, SyntaxError) as err:
        raise argparse.ArgumentTypeError(
            'expected a Python literal such as "(0, 1)", got {!r}'.format(text)) from err
    if (not isinstance(cols, (tuple, list)) or len(cols) != 2
            or not all(isinstance(c, int) for c in cols)):
        raise argparse.ArgumentTypeError(
            'expected exactly two integer column indices, got {!r}'.format(text))
    return tuple(cols)


def main():
    """Command-line entry point.

    Loads two columns from a delimited text file, fits a line with
    ``computeRegParams`` and prints every fitted parameter.  When
    ``scipy.stats`` is importable and ``--confidence`` is non-zero, also
    prints Student-t confidence intervals for the slope and for the fitted
    value at the mean of x.  Unless ``--no-plot`` is given and when
    matplotlib is importable, shows the data, the fitted line and (if the
    intervals were computed) a shaded confidence band.

    Invalid ``--use-cols`` or ``--confidence`` values are reported through
    argparse (usage message, exit status 2).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('fpath', help='path to CSV file')
    parser.add_argument('--no-plot', dest='plot', action='store_false', default=True,
        help='do not show a plot')
    parser.add_argument('--confidence', type=float, default=0.95,
        help='confidence level in [0, 1]; 0 disables confidence intervals')
    parser.add_argument('--ci-plot-points', type=int, default=100,
        help='number of x positions used to draw the confidence band; 0 disables it')
    parser.add_argument('--delimiter', default=',')
    parser.add_argument('--skip-rows', type=int, default=0,
        help='number of leading rows to skip')
    parser.add_argument('--use-cols', type=_parse_use_cols, default=(0, 1),
        help='pair of column indices to read as x and y, e.g. "(0, 1)"')
    parser.add_argument('--max-rows', type=int,
        help='maximum number of rows to read')
    args = parser.parse_args()
    if not 0 <= args.confidence <= 1:
        parser.error('--confidence must be between 0 and 1')

    x, y = np.loadtxt(args.fpath, skiprows=args.skip_rows, usecols=args.use_cols,
        delimiter=args.delimiter, max_rows=args.max_rows, unpack=True)
    params = computeRegParams(x, y)
    for k, v in params.items():
        print('{}:\t{}'.format(k, v))

    # scipy is optional: without it only the point estimates are reported.
    try:
        from scipy import stats
        has_stats = True
    except ImportError:
        has_stats = False
        print('scipy.stats not found', file=sys.stderr)

    n, xbar, ybar, slope = params['n'], params['xbar'], params['ybar'], params['slope']
    slopeStd, ybarStd = params['slopeStd'], params['ybarStd']

    # (lo, hi) t-quantiles; stays None when no interval could be computed.
    ci = None
    if has_stats and args.confidence:
        if n > 2:
            ci = stats.t(df=n-2).interval(args.confidence)
            lo, hi = ci
            print('slope confidence interval: [{}, {}] = {} ± {}'.format(
                slope + lo * slopeStd, slope + hi * slopeStd, slope, hi * slopeStd))
            print('ybar confidence interval: [{}, {}] = {} ± {}'.format(
                ybar + lo * ybarStd, ybar + hi * ybarStd, ybar, hi * ybarStd))
        else:
            # The t distribution needs n - 2 > 0 degrees of freedom.
            print('confidence intervals need more than 2 data points',
                file=sys.stderr)

    if not args.plot:
        return
    # matplotlib is optional as well; skip plotting when it is missing.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print('matplotlib.pyplot not found', file=sys.stderr)
        return
    # seaborn only restyles the figure, so its absence is silently ignored.
    try:
        import seaborn as sns
        sns.set()
    except ImportError:
        pass

    xends = np.array([np.min(x), np.max(x)])
    plt.plot(x, y, 'b.')
    plt.plot(xends, ybar + slope * (xends - xbar), 'r-')
    if ci is not None and args.ci_plot_points:
        lo, hi = ci
        plotx = np.linspace(xends[0], xends[1], args.ci_plot_points)
        sse, sxx = params['sse'], params['sxx']
        # Standard error of the fitted line at each plotted x.
        plotstd = np.sqrt((1/n + np.square(plotx-xbar)/sxx) * sse / (n-2))
        ploty = ybar + slope * (plotx - xbar)
        plt.fill_between(plotx, ploty + lo * plotstd, ploty + hi * plotstd,
            facecolor='#ff808040')
    plt.show()
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment