Skip to content

Instantly share code, notes, and snippets.

@mebusw
Created May 1, 2017 09:54
Show Gist options
  • Save mebusw/372b7c6257a1058eaed3fac151944800 to your computer and use it in GitHub Desktop.
Save mebusw/372b7c6257a1058eaed3fac151944800 to your computer and use it in GitHub Desktop.
Linear Unit of neural network with numpy
#!/usr/local/bin/python
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib.pyplot as plt
def predict(input_vec, f, weights=None):
    """Return f(input_vec . weights): the unit's activated weighted sum.

    Parameters:
        input_vec: 1-D array-like of inputs, already extended with the
            constant bias component (a trailing 1).
        f: activation function applied to the scalar weighted sum
            (e.g. a step function for a perceptron, identity for a
            linear unit).
        weights: optional weight vector. Defaults to the module-level
            global W so existing callers keep working unchanged.
    """
    w = W if weights is None else weights
    return f(np.dot(input_vec, w))
def train(input_vecs, labels, iterations, rate, f):
    """Run `iterations` full passes of delta-rule training over the data.

    Parameters:
        input_vecs: list of input vectors (without the bias component).
        labels: expected outputs, aligned one-to-one with input_vecs.
        iterations: number of complete passes over the data set.
        rate: learning rate used by the weight update.
        f: activation function forwarded to predict().
    """
    # `range` instead of the Python-2-only `xrange`: identical behavior
    # here, and the code no longer raises NameError on Python 3.
    for _ in range(iterations):
        one_iteration(input_vecs, labels, rate, f)
def one_iteration(input_vecs, labels, rate, f):
    """Make one pass over every (input, label) pair, updating the
    global weights after each individual sample (online learning)."""
    for vec, expected in zip(input_vecs, labels):
        # Append a constant 1 so the last weight acts as the bias term.
        extended = np.append(vec, [1])
        actual = predict(extended, f)
        update_W(extended, actual, expected, rate)
def update_W(input_vec, output, label, rate):
    """Apply one delta-rule step to the global weight vector W.

    Moves W in the direction of input_vec, scaled by the learning rate
    and the prediction error (label - output).
    """
    global W
    # Error is positive when the prediction undershoots the label,
    # so W is pushed toward producing a larger output for this input.
    error = label - output
    W = W + rate * error * input_vec
if __name__ == '__main__':
    # --- Part 1: learn logical AND with a step-activated perceptron ---
    input_vecs = [[1, 1], [0, 0], [1, 0], [0, 1]]
    # Expected outputs, aligned one-to-one with the inputs:
    # [1,1] -> 1, [0,0] -> 0, [1,0] -> 0, [0,1] -> 0
    labels = np.array([1, 0, 0, 0])
    # Model: w2*x2 + w1*x1 + bias, i.e. w2*x2 + w1*x1 + w0
    input_num = 2
    W = np.array([0.0] * (input_num + 1))
    step = lambda x: 1 if x > 0 else 0
    train(input_vecs, labels, iterations=10, rate=0.1, f=step)
    # print(...) works on both Python 2 and Python 3; the original
    # `print W` statement form is a SyntaxError on Python 3.
    print(W)
    for x in input_vecs:
        # Append the bias input 1 before predicting.
        print('AND(%s) = %d' % (x, predict(x + [1], f=step)))

    ###############
    # --- Part 2: linear regression (monthly salary vs. years worked) ---
    # Model: w1*x1 + bias, i.e. w1*x1 + w0
    input_vecs = [[5], [3], [8], [1.4], [10.1]]
    # Expected monthly salaries, aligned one-to-one with the inputs.
    labels = [5500, 2300, 7600, 1800, 11400]
    input_num = 1
    W = np.array([0.0] * (input_num + 1))
    # Identity activation turns the unit into a plain linear model.
    train(input_vecs, labels, iterations=10, rate=0.01, f=lambda x: x)
    print(W)
    for x in [[3.4], [15], [1.5], [6.3]]:
        t = predict(x + [1], f=lambda x: x)
        print('Work %s years, monthly salary = %.2f' % (x, t))

    # Plot the fitted line against the training samples.
    plt.figure(1)
    z = np.linspace(0, 12, 100)
    plt.plot(z, z * W[0] + W[1])
    for (i, l) in zip(input_vecs, labels):
        plt.plot(i, l, 'g^')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment