Created
May 1, 2017 09:54
-
-
Save mebusw/372b7c6257a1058eaed3fac151944800 to your computer and use it in GitHub Desktop.
Linear unit of a neural network, implemented with NumPy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/local/bin/python | |
| # -*- coding: UTF-8 -*- | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
def predict(input_vec, f, weights=None):
    """Return the activated output of the unit for a single input vector.

    Args:
        input_vec: input features with the constant bias input ``1`` already
            appended (a list or NumPy array whose length matches the weights).
        f: activation function applied to the weighted sum.
        weights: weight vector to use. Defaults to the module-level ``W``
            (the weights being trained), so existing callers are unaffected.

    Returns:
        ``f(np.dot(input_vec, weights))``.
    """
    if weights is None:
        # Fall back to the global weights that update_W rebinds during training.
        weights = W
    return f(np.dot(input_vec, weights))
def train(input_vecs, labels, iterations, rate, f):
    """Train the module-level weights ``W`` for a fixed number of epochs.

    Each epoch is one call to ``one_iteration``, i.e. one pass over every
    (input, label) pair. ``W`` must already be initialised by the caller.

    Args:
        input_vecs: training inputs, without the bias input.
        labels: expected outputs, one per entry of ``input_vecs``.
        iterations: number of passes over the training data.
        rate: learning rate.
        f: activation function.
    """
    # range() instead of the Python 2-only xrange(), so this runs on Python 3 too.
    for _ in range(iterations):
        one_iteration(input_vecs, labels, rate, f)
def one_iteration(input_vecs, labels, rate, f):
    """Run one training epoch: a single pass with per-sample weight updates.

    For each (input, label) pair, the constant bias input ``1`` is appended.
    The unit's output is then computed with the current weights, and the
    weights are adjusted immediately before the next sample is seen.
    """
    for vec, target in zip(input_vecs, labels):
        # The appended 1 makes the last weight act as the bias term.
        augmented = np.append(vec, [1])
        update_W(augmented, predict(augmented, f), target, rate)
def update_W(input_vec, output, label, rate):
    """Apply one delta-rule step to the module-level weights ``W``.

    Computes ``W <- W + rate * (label - output) * input_vec``. ``W`` is
    rebound to a new array (not modified in place).
    """
    global W
    scale = rate * (label - output)
    W = W + scale * input_vec
if __name__ == '__main__':
    def step(x):
        """Step activation: 1 for a positive weighted sum, otherwise 0."""
        return 1 if x > 0 else 0

    def identity(x):
        """Identity activation, used for the linear (regression) unit."""
        return x

    # --- Perceptron learning the AND function -----------------------------
    input_vecs = [[1, 1], [0, 0], [1, 0], [0, 1]]
    # Expected outputs; they must correspond one-to-one with input_vecs:
    # [1,1] -> 1, [0,0] -> 0, [1,0] -> 0, [0,1] -> 0
    labels = np.array([1, 0, 0, 0])
    # Model of the form w2*x2 + w1*x1 + bias (i.e. w2*x2 + w1*x1 + w0),
    # hence one extra weight for the bias.
    input_num = 2
    W = np.array([0.0] * (input_num + 1))
    train(input_vecs, labels, iterations=10, rate=0.1, f=step)
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(W)
    for x in input_vecs:
        print('AND(%s) = %d' % (x, predict(x + [1], f=step)))

    # --- Linear unit: monthly salary vs. years of work --------------------
    # Model of the form w1*x1 + bias (i.e. w1*x1 + w0).
    input_vecs = [[5], [3], [8], [1.4], [10.1]]
    # Expected outputs (monthly salary); they must correspond one-to-one
    # with input_vecs.
    labels = [5500, 2300, 7600, 1800, 11400]
    input_num = 1
    W = np.array([0.0] * (input_num + 1))
    train(input_vecs, labels, iterations=10, rate=0.01, f=identity)
    print(W)
    for x in [[3.4], [15], [1.5], [6.3]]:
        t = predict(x + [1], f=identity)
        print('Work %s years, monthly salary = %.2f' % (x, t))

    # Plot the fitted line z*W[0] + W[1] together with the training samples.
    plt.figure(1)
    z = np.linspace(0, 12, 100)
    plt.plot(z, z * W[0] + W[1])
    for (i, l) in zip(input_vecs, labels):
        plt.plot(i, l, 'g^')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment