Last active
December 1, 2019 04:58
-
-
Save gary136/c53827875e828907dd70f1152373a349 to your computer and use it in GitHub Desktop.
gd2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def func(xlist):
    """Evaluate the quadratic test objective on the first two entries of *xlist*.

    f(x0, x1) = 2*(x0 - 2.3)**2 + 3*(x0 - 4.1)*(x1 - 6.5) + 5*(x1 - 8.0)**2

    Its Hessian is [[4, 3], [3, 10]], which is positive definite, so the
    function is convex with a single minimum. Entries beyond index 1 are
    ignored.
    """
    x0 = xlist[0]
    x1 = xlist[1]
    quad_x0 = 2 * (x0 - 2.3) ** 2
    cross = 3 * (x0 - 4.1) * (x1 - 6.5)
    quad_x1 = 5 * (x1 - 8.0) ** 2
    return quad_x0 + cross + quad_x1
# Helpers: finite-difference gradient, gradient descent, and plotting
def partial_derivative(f, x_list, odr, epsilon=1e-6, **kargs):
    """Approximate the partial derivative of *f* w.r.t. component *odr*.

    Uses the forward finite difference
    ``(f(x + epsilon * e_odr) - f(x)) / epsilon``, whose truncation error is
    O(epsilon).

    Parameters:
        f: callable evaluated as ``f(x, **kargs)``.
        x_list: point of evaluation; a list or ndarray supporting ``.copy()``
            and item assignment. It is not modified.
        odr: index of the component to differentiate with respect to.
        epsilon: finite-difference step size.
        **kargs: extra keyword arguments forwarded to *f*.

    Returns:
        The finite-difference estimate of df/dx[odr].
    """
    dtype = getattr(x_list, "dtype", None)
    if dtype is not None and dtype.kind in ("b", "i", "u"):
        # An integer/bool array would truncate x + epsilon back to x when
        # assigned in place, making the difference (and result) silently 0.
        h_list = x_list.astype(float)
    else:
        h_list = x_list.copy()
    h_list[odr] = x_list[odr] + epsilon
    fx = f(x_list, **kargs)
    fh = f(h_list, **kargs)
    return (fh - fx) / epsilon
def gradient(f, x_list, **kargs):
    """Return the numerical gradient of *f* at *x_list* as a Python list.

    Entry ``i`` is ``partial_derivative(f, x_list, i, **kargs)`` for every
    ``i`` in ``range(len(x_list))``. Keyword arguments are forwarded as-is to
    ``partial_derivative``, which consumes ``epsilon`` and passes the rest on
    to *f*.
    """
    n_dims = len(x_list)
    return [partial_derivative(f, x_list, axis, **kargs) for axis in range(n_dims)]
def g_d(W, f, epochs, lr):
    """Run plain gradient descent on *f*, starting from *W*.

    The gradient is estimated numerically with ``gradient``.

    Parameters:
        W: starting weights (array_like; a 1-D vector is the intended shape).
            Lists are accepted, and integer/bool input is promoted to float.
        f: objective, called as ``f(W)`` and expected to return a scalar.
        epochs: number of update steps.
        lr: learning rate (step size).

    Returns:
        (loss_container, W_container):
        loss_container has shape (epochs + 1, 1). Row 0 is f at the start and
        row i + 1 is f after step i.
        W_container has shape (epochs + 1, W.size), with the weights
        flattened in the same row layout.
    """
    W = np.asarray(W)
    if W.dtype.kind in ("b", "i", "u"):
        # Integer weights would swallow the finite-difference step and
        # lose the first update; work in floating point instead.
        W = W.astype(float)
    loss_container = np.zeros((epochs + 1, 1))
    W_container = np.zeros((epochs + 1, W.size))
    loss_container[0] = f(W)
    # Flatten exactly as inside the loop, so every row is stored alike.
    W_container[0] = W.reshape(-1)
    for i in range(epochs):
        # Numerical gradient, one partial derivative per entry of W. It is
        # reshaped to W's shape so the update can never broadcast to a
        # larger matrix.
        grd = np.asarray(gradient(f, x_list=W)).reshape(W.shape)
        # Step against the gradient (creates a new array; the caller's W is untouched).
        W = W - lr * grd
        loss_container[i + 1] = f(W)
        W_container[i + 1] = W.reshape(-1)
    return loss_container, W_container
def show(W, e, l, f=None):
    """Run gradient descent and plot the loss curve and weight trajectories.

    Calls ``g_d(W, f=f, epochs=e, lr=l)``. It then shows a 1x2 figure with
    the loss per epoch on the left and every weight component per epoch on
    the right, and prints the initial and final loss/weights.

    Parameters:
        W: starting weights, passed to ``g_d``.
        e: number of epochs.
        l: learning rate.
        f: objective to minimise. Defaults to the module-level ``func``,
            which is looked up at call time, as in the original.

    Returns:
        The ``(loss_container, W_container)`` pair returned by ``g_d``.
    """
    if f is None:
        f = func
    loss_container, W_container = g_d(W, f=f, epochs=e, lr=l)
    fig = plt.figure(figsize=(12, 3))
    ax_loss = fig.add_subplot(121)
    ax_w = fig.add_subplot(122)

    ax_loss.plot(loss_container, label='Loss', color='C3')
    ax_loss.set_xlabel('epochs', fontsize=12)
    ax_loss.set_ylabel('loss', fontsize=12)
    ax_loss.legend()

    # Plot every weight component rather than assuming exactly two
    # (a hard-coded column 1 raises IndexError for a 1-D problem).
    for j in range(W_container.shape[1]):
        ax_w.plot(W_container[:, j], label=f'W{j}')
    ax_w.set_xlabel('epochs', fontsize=12)
    ax_w.set_ylabel('w', fontsize=12)
    ax_w.legend()

    plt.show()
    print(f'ini_loss = {loss_container[0]}, ini_w = {W_container[0]}')
    print(f'finl_loss = {loss_container[-1]}, finl_w = {W_container[-1]}')
    return loss_container, W_container
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment