Skip to content

Instantly share code, notes, and snippets.

@th3terrorist
Last active August 18, 2022 19:59
Show Gist options
  • Save th3terrorist/ab51127d082cc32ce85e8403f463e2af to your computer and use it in GitHub Desktop.
Save th3terrorist/ab51127d082cc32ce85e8403f463e2af to your computer and use it in GitHub Desktop.
A deliberately simple gradient descent implementation for single-variable linear regression, adapted from Andrew Ng's course.
import math, numpy as np
from typing import Tuple, Callable
from numpy.typing import ArrayLike
# Toy training set: two samples, where y_train[k] is the target for x_train[k].
x_train, y_train = np.array([1.0, 2.0]), np.array([300.0, 500.0])
def compute_cost(x: ArrayLike, y: ArrayLike, w: float, b: float) -> float:
    """Return the squared-error cost of the line ``f(x) = w * x + b``.

    J(w, b) = 1 / (2m) * sum_i (w * x[i] + b - y[i]) ** 2

    Args:
        x: Input values; any array-like (list, tuple or ndarray) of m samples.
        y: Target values, same shape as ``x``.
        w: Slope parameter.
        b: Intercept parameter.

    Returns:
        The cost as a Python float.

    Raises:
        ValueError: If ``x`` and ``y`` differ in shape or contain no samples.
    """
    # Convert up front so lists/tuples (valid ArrayLike values) work too;
    # the previous implementation relied on ``x.shape`` and rejected them.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError(f"x and y must have the same shape, got {x.shape} and {y.shape}")
    m = x.size
    if m == 0:
        raise ValueError("compute_cost needs at least one sample")
    residuals = w * x + b - y
    return float(np.sum(residuals ** 2) / (2 * m))
def compute_gradient(x: ArrayLike, y: ArrayLike, w: float, b: float) -> Tuple[float, float]:
    """Return the partial derivatives of the squared-error cost w.r.t. ``w`` and ``b``.

    dJ/dw = 1/m * sum_i (w * x[i] + b - y[i]) * x[i]
    dJ/db = 1/m * sum_i (w * x[i] + b - y[i])

    Args:
        x: Input values; any array-like (list, tuple or ndarray) of m samples.
        y: Target values, same shape as ``x``.
        w: Slope parameter.
        b: Intercept parameter.

    Returns:
        ``(dj_dw, dj_db)`` as Python floats.

    Raises:
        ValueError: If ``x`` and ``y`` differ in shape or contain no samples.
    """
    # Accept any ArrayLike, not just ndarrays (``x.shape`` fails on lists).
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError(f"x and y must have the same shape, got {x.shape} and {y.shape}")
    m = x.size
    if m == 0:
        raise ValueError("compute_gradient needs at least one sample")
    residuals = w * x + b - y
    dj_dw = np.sum(residuals * x) / m
    dj_db = np.sum(residuals) / m
    return float(dj_dw), float(dj_db)
def gradient_descent(x: ArrayLike, y: ArrayLike, w: float, b: float, alpha: float, num_iters: int,
                     cost_function: Callable[[ArrayLike, ArrayLike, float, float], float],
                     gradient_function: Callable[[ArrayLike, ArrayLike, float, float], Tuple[float, float]],
                     history_limit: int = 10000) -> Tuple[float, float, list, list]:
    """Fit ``w`` and ``b`` with ``num_iters`` steps of batch gradient descent.

    Args:
        x: Input values, passed unchanged to the cost and gradient functions.
        y: Target values, passed unchanged to the cost and gradient functions.
        w: Initial slope.
        b: Initial intercept.
        alpha: Learning rate (step size).
        num_iters: Number of update steps to run.
        cost_function: ``cost_function(x, y, w, b)`` returning the scalar cost.
        gradient_function: ``gradient_function(x, y, w, b)`` returning ``(dj_dw, dj_db)``.
        history_limit: Only the first ``history_limit`` iterations are recorded
            in the returned histories, bounding their size on long runs.

    Returns:
        ``(w, b, J_history, p_history)``: the final parameters, the list of
        costs after each recorded step, and the list of ``(w, b)`` tuples
        after each recorded step.

    Progress is printed about ten times per run (every ``ceil(num_iters / 10)``
    iterations).
    """
    J_history = []
    p_history = []
    # Loop-invariant reporting interval; max() guards the modulo below.
    print_every = max(1, math.ceil(num_iters / 10))
    for i in range(num_iters):
        # Both gradients are computed from the same (w, b) before either is
        # updated, i.e. a simultaneous update.
        dj_dw, dj_db = gradient_function(x, y, w, b)
        b = b - alpha * dj_db
        w = w - alpha * dj_dw
        if i < history_limit:
            J_history.append(cost_function(x, y, w, b))
            p_history.append((w, b))
        if i % print_every == 0:
            # Report the cost of the *current* parameters. The old code read
            # J_history[-i], which returned a stale entry and raised
            # IndexError once i exceeded the frozen history length.
            cost = J_history[-1] if i < history_limit else cost_function(x, y, w, b)
            print(f"Iteration {i:4}: Cost {cost:0.2e}",
                  f"dj_dw: {dj_dw:0.3e}, dj_db: {dj_db:0.3e},",
                  f"w: {w:0.3e}, b:{b:0.5e}")
    return w, b, J_history, p_history
def main():
    """Fit the toy training set with gradient descent and print the results."""
    w_init = 0
    b_init = 0
    iterations = 10000
    tmp_alpha = 1.0e-2
    w_final, b_final, J_hist, p_hist = gradient_descent(
        x_train, y_train, w_init, b_init, tmp_alpha, iterations,
        compute_cost, compute_gradient,
    )
    # The tuple printed here must close the "(" it opens.
    print(f"(w, b) found using GDA: ({w_final: 8.4f}, {b_final:8.4f})")
    # Full per-iteration histories (one entry per recorded step).
    print(J_hist)
    print(p_hist)
    print("============================================\n\n\t\tGOOD JOB :^)\n\n============================================")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment