Created
July 5, 2021 11:42
-
-
Save Park-Developer/b2ded3dbf231032ee3284bc122d9eaa7 to your computer and use it in GitHub Desktop.
[ML] Gradient Descent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
%matplotlib inline | |
np.random.seed(0) | |
# y=4X+6을 근사, 임의의 값을 노이즈로 만듬 | |
X=2*np.random.rand(100,1) | |
y=6+4*X+np.random.rand(100,1) | |
# X,y 데이터 세트 산점도로 시각화 | |
plt.scatter(X,y) | |
# 비용 함수 정의 | |
def get_cost(y,y_pred): | |
N=len(y) | |
cost=np.sum(np.square(y-y_pred))/N | |
return cost | |
# w1과 w0을 업데이트 할 w1_update, w0_update를 반환 | |
import numpy as np | |
def get_weight_updates(w1,w0,X,y,learning_rate=0.01): | |
N=len(y) | |
# 먼저 w1_update, w0_update를 각각 w1,w0의 shape와 동일한 크기를 가진 0 값으로 초기화 | |
w1_update=np.zeros_like(w1) | |
w0_update=np.zeros_like(w0) | |
# 예측 배열 계산하고 예측과 실제 값의 차이 계산 | |
y_pred=np.dot(X,w1.T)+w0 | |
diff=y-y_pred | |
# w0_update를 dot 행렬 연산으로 구하기 위해 모두 1값을 가진 행렬 생성 | |
w0_factors=np.ones((N,1)) | |
# w1과 w0을 업데이트할 w1_update와 w0_update 계산 | |
w1_update=-(2/N)*learning_rate*(np.dot(X.T,diff)) | |
w0_update=-(2/N)*learning_rate*(np.dot(w0_factors.T,diff)) | |
return w1_update,w0_update | |
# 입력 인자 iters로 주어진 횟수만큼 반복적으로 w1과 w0을 업데이트 적용함 | |
import numpy as np | |
def gradient_descent_steps(X,y,iters=10000): | |
# w0과 w1을 모두 0으로 초기화 | |
w0=np.zeros((1,1)) | |
w1=np.zeros((1,1)) | |
# 인자로 주어진 iters 만큼 반복적으로 get_weight_updates() 호출해 w1과 w0 업데이트 수행 | |
for ind in range(iters): | |
w1_update,w0_update=get_weight_updates(w1,w0,X,y,learning_rate=0.01) | |
w1=w1-w1_update | |
w0=w0-w0_update | |
return w1,w0 | |
w1,w0=gradient_descent_steps(X,y,iters=1000) | |
print("w1:{0:.3f} w0:{1:.3f}".format(w1[0,0],w0[0,0])) | |
y_pred=w1[0,0]*X+w0 | |
print("Gradient Descent Total Cost: {0:.4f}".format(get_cost(y,y_pred))) | |
plt.scatter(X,y) | |
plt.plot(X,y_pred) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment