Created
June 16, 2018 07:19
-
-
Save niranjannitesh/b66e017c852f94b501949c6108a849b2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
def gradient_desc(m, b, data, learning_rate):
    """Perform one pass of gradient descent over ``data``.

    The slope ``m`` and intercept ``b`` are nudged after every individual
    point (a per-sample update), not once per pass.

    Each row of ``data`` supplies x in column 0 and y in column 1. Both are
    passed through ``float()``, so string rows from ``csv.reader`` work.

    Returns the updated ``(m, b)`` pair.
    """
    for row in data:
        x_val = float(row[0])
        y_val = float(row[1])
        residual = y_val - (m * x_val + b)
        # Step against the gradient of residual**2
        # (d/dm = -2*residual*x, d/db = -2*residual).
        # Both updates use the same pre-step residual.
        m += 2 * residual * x_val * learning_rate
        b += 2 * residual * learning_rate
    return m, b
def gradient_desc_runner(data, m, b, learning_rate, iteration):
    """Run ``gradient_desc`` ``iteration`` times and return the final ``(m, b)``.

    Each row of ``data`` supplies x in column 0 and y in column 1. These
    values are converted with ``float()`` once, before the first pass, rather
    than on every pass. ``gradient_desc`` applies ``float()`` again, which is
    a no-op on values that are already floats, so results are unchanged.
    """
    # Build the range first so a non-int ``iteration`` still raises TypeError.
    passes = range(iteration)
    if not passes:
        # Zero (or negative) passes: return the inputs untouched, no parsing.
        return m, b
    # Hoist the string -> float parsing out of the (potentially huge) loop.
    points = [(float(row[0]), float(row[1])) for row in data]
    for _ in passes:
        m, b = gradient_desc(m, b, points, learning_rate)
    return m, b
def total_error(m, b, data):
    """Return the mean squared error of the line ``y = m * x + b`` over ``data``.

    Each row supplies x in column 0 and y in column 1 (converted with
    ``float()``). Despite the function's name, the sum of squared residuals
    is divided by the number of rows, so the result is a *mean*.

    An empty ``data`` yields ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    count = len(data)
    if count == 0:
        return 0.0
    # Local name deliberately differs from the function name to avoid shadowing.
    squared_sum = 0
    for row in data:
        x = float(row[0])
        y = float(row[1])
        squared_sum += (y - (m * x + b)) ** 2
    return squared_sum / float(count)
def run(path='data.csv', *, iteration=100000, learning_rate=0.0002):
    """Run the program.

    Load ``(x, y)`` rows from the CSV file at ``path``. Fit a line with
    ``iteration`` passes of gradient descent, starting from ``m = 0, b = 0``.
    Print the error before and after training.

    The defaults reproduce the original hard-coded file name and
    hyperparameters.

    Blank lines in the CSV are skipped. ``csv.reader`` yields them as empty
    rows, which would otherwise fail with ``IndexError`` when column 0 is read.
    """
    # read the dataset; parse the x/y columns to floats once, up front
    with open(path, newline='') as csvfile:
        data = [(float(row[0]), float(row[1]))
                for row in csv.reader(csvfile) if row]
    m = 0  # initial m (slope)
    b = 0  # initial b (intercept)
    print(f"Starting gradient decent at m = {m}, b = {b}, error={total_error(m, b, data)}")
    print("[*] Started Running Algorithm....")
    m, b = gradient_desc_runner(data, m, b, learning_rate, iteration)
    print(f"After running {iteration} iteration m = {m}, b = {b}, error={total_error(m, b, data)}")
if __name__ == '__main__':
    # Script entry point: train and report only when executed directly,
    # not when this module is imported.
    run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment