Skip to content

Instantly share code, notes, and snippets.

@wware
Last active August 29, 2015 14:08
Show Gist options
  • Save wware/5aec5296b941890af829 to your computer and use it in GitHub Desktop.
Save wware/5aec5296b941890af829 to your computer and use it in GitHub Desktop.
class LinearRegression:
    """
    Incremental least-squares fit of a straight line y = m * x + b.

    Points are fed one at a time with add(); only running sums of x, y,
    x*x, x*y and y*y are stored, so the object does not keep the points
    themselves.  get() returns the fitted (m, b) and error() the sum of
    squared residuals of that fit.  Both return None until at least two
    points have been added.

    >>> lr = LinearRegression()
    >>> lst = [(0., 0.)]
    >>> lr.add(lst[0][0], lst[0][1])
    >>> lr.get()
    >>> lr.error()
    >>> for nxt in [(1., 1.), (2., 2.), (4., 3.), (-1., -0.5)]:
    ...     lst.append(nxt)
    ...     lr.add(nxt[0], nxt[1])
    ...     m, b = lr.get()
    ...     print("%d %r %r" % (len(lr), (m, b), round(lr.error(), 12)))
    2 (1.0, -0.0) 0.0
    3 (1.0, -0.0) 0.0
    4 (0.7428571428571429, 0.2) 0.171428571429
    5 (0.7364864864864865, 0.21621621621621623) 0.172297297297
    >>> import pprint
    >>> pprint.pprint([(x, y, m * x + b) for x, y in lst])
    [(0.0, 0.0, 0.21621621621621623),
     (1.0, 1.0, 0.9527027027027027),
     (2.0, 2.0, 1.6891891891891893),
     (4.0, 3.0, 3.1621621621621623),
     (-1.0, -0.5, -0.5202702702702703)]

    If every x is the same, the best-fit line would be vertical, which
    y = m * x + b cannot represent:

    >>> bad = LinearRegression()
    >>> bad.add(1., 0.)
    >>> bad.add(1., 1.)
    >>> bad.get()
    Traceback (most recent call last):
        ...
    ValueError: cannot fit a line: x values are all (nearly) equal or not finite
    """

    def __init__(self):
        """Start an empty fit with no points."""
        self._count = 0
        # Running sums; float literals so integer inputs still divide as floats.
        self._sum_x = 0.
        self._sum_y = 0.
        self._sum_xx = 0.
        self._sum_xy = 0.
        self._sum_yy = 0.
        # Cached result of get(), valid while _last_computed == _count.
        self._m = None
        self._b = None
        self._last_computed = -1

    def __len__(self):
        """Return the number of points added so far."""
        return self._count

    def add(self, x, y):
        """Add the point (x, y) to the fit by updating the running sums."""
        self._count += 1
        self._sum_x += x
        self._sum_y += y
        self._sum_xx += x ** 2
        self._sum_xy += x * y
        self._sum_yy += y ** 2

    def get(self):
        """
        Return the least-squares fit as a tuple (m, b) for y = m * x + b.

        Returns None if fewer than two points have been added.  The result
        is cached and only recomputed after further add() calls.

        Raises ValueError if the x values are (numerically) all equal or
        not finite, since no line of this form then fits the data.
        """
        if self._count < 2:
            return None
        if self._count > self._last_computed:
            # det == -n**2 * var(x) (population variance); zero iff all x equal.
            det = self._sum_x ** 2 - self._count * self._sum_xx
            # Tolerance scales with the data (n * sum(x**2)), so closely
            # spaced x values still fit while round-off from identical x
            # values is rejected.  Written as "not >" so a NaN det is also
            # rejected, and as an explicit raise so it survives python -O.
            tolerance = 1.e-12 * self._count * self._sum_xx
            if not abs(det) > tolerance:
                raise ValueError(
                    "cannot fit a line: x values are all (nearly) equal"
                    " or not finite")
            self._m = (self._sum_x * self._sum_y - self._count * self._sum_xy) / det
            self._b = (-self._sum_xx * self._sum_y + self._sum_x * self._sum_xy) / det
            self._last_computed = self._count
        return (self._m, self._b)

    def error(self):
        """
        Return the sum of squared residuals, sum((y - (m * x + b)) ** 2),
        of the current fit, or None if fewer than two points have been added.

        Raises ValueError in the same situations as get().
        """
        if self._count < 2:
            return None
        m, b = self.get()
        # Expanded form of the residual sum, evaluable from the running sums.
        sse = (
            self._sum_yy
            - 2 * (m * self._sum_xy + b * self._sum_y)
            + m**2 * self._sum_xx + 2 * m * b * self._sum_x + b**2 * self._count
        )
        # The expansion subtracts large, nearly equal terms, so for an
        # (almost) perfect fit round-off can leave it slightly below zero,
        # which a sum of squares never is.  NaN falls through unchanged.
        if sse <= 0.0:
            return 0.0
        return sse
if __name__ == "__main__":
    from doctest import testmod

    # Executing this file directly runs the docstring examples as self-tests.
    testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment