Created
November 7, 2017 14:28
-
-
Save vinupriyesh/2b8e264ce70b6c7a9d375f9fc6546819 to your computer and use it in GitHub Desktop.
2-layer neural network to categorize images from the CIFAR-10 dataset (http://www.cs.toronto.edu/~kriz/cifar.html)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Dataset is taken from http://www.cs.toronto.edu/~kriz/cifar.html
2-layer neural network that classifies a 32x32 pixel color image into one of 10 categories
@Author : Vinu Priyesh V.A.
'''
import matplotlib.pyplot as plt | |
import numpy as np | |
def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    ``encoding='bytes'`` makes Python 3 read Python 2 ``str`` objects as
    ``bytes``, which is why callers index the result with byte keys such as
    ``b'data'``.

    Security: ``pickle.load`` can execute arbitrary code while loading. Only
    call this on trusted files (e.g. the official CIFAR download).
    """
    import pickle
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
def get_data(file, num_classes=10):
    """Load one pickled CIFAR batch and return it with one sample per column.

    Args:
        file: Path to a pickled batch file (e.g. ``data_batch_1``).
        num_classes: Number of label classes used for the one-hot encoding.
            Defaults to 10.

    Returns:
        ``(X, Y, names)``:
        ``X`` is the ``uint8`` data matrix transposed so that ``X[:, i]`` is
        sample ``i``; ``Y`` is a ``(num_classes, n)`` float one-hot matrix
        with ``Y[label_i, i] == 1``; ``names`` is an array of the values
        stored under ``b'filenames'``.
    """
    batch = unpickle(file)
    print("Unpacking {}".format(batch[b'batch_label']))
    # Transpose so each column is one flattened sample.
    X = np.asarray(batch[b'data']).T.astype("uint8")
    labels = np.asarray(batch[b'labels'])
    # Sample count comes from the data instead of a hard-coded 10000.
    n = labels.shape[0]
    Y = np.zeros((num_classes, n))
    # Vectorised one-hot: set row labels[i] of column i to 1 for every i.
    Y[labels, np.arange(n)] = 1
    names = np.asarray(batch[b'filenames'])
    return X, Y, names
def visualize_image(X, Y, names, id):
    """Display sample ``id`` from ``X`` and print its label column.

    Args:
        X: Data matrix with one flattened image per column. Each column is
            reshaped to ``(3, 32, 32)``, i.e. assumed channel-major (all of one
            colour channel, then the next), each channel a 32x32 row-major
            grid -- the CIFAR layout; confirm for other data.
        Y: One-hot label matrix with one sample per column.
        names: Per-sample names; ``bytes`` values are decoded for the title.
        id: Column index of the sample to show.
    """
    rgb = X[:, id]
    print(rgb.shape)
    # reshape yields (channel, row, col); imshow wants (row, col, channel).
    # A plain ``.T`` would reverse *all* axes and show the image transposed.
    img = rgb.reshape(3, 32, 32).transpose(1, 2, 0)
    print(img.shape)
    plt.imshow(img)
    title = names[id]
    if isinstance(title, bytes):
        title = title.decode("utf-8", errors="replace")
    plt.title(title)
    # Labels are stored per column, so take column ``id`` (not row ``id``).
    print(Y[:, id])
    plt.show()
def plot_cost_graph(values):
    """Show the recorded training-cost values as a line plot.

    A grid is drawn and the y-axis is fixed to the range 0..2.
    """
    lower, upper = 0, 2
    plt.plot(values)
    plt.grid(True)
    plt.ylim([lower, upper])
    plt.show()
##########
# Logistic sigmoid, used in this file as the output-layer activation. ReLU is a
# common choice for hidden layers, but an output read as a probability needs the
# (0, 1) range the sigmoid provides.
def sigmoid(z):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-z))``.

    Evaluated in a numerically stable way: ``exp`` is only applied to
    ``-|z|`` (never positive), so large-magnitude inputs cannot overflow,
    and both tails keep full relative precision.

    Args:
        z: Scalar or array-like input.

    Returns:
        Array of the same shape as ``z``, or a NumPy scalar for scalar input.
    """
    z = np.asarray(z)
    e = np.exp(-np.abs(z))
    # z >= 0: 1 / (1 + e^-z);  z < 0: e^z / (1 + e^z)  -- the same function.
    s = np.where(z >= 0, 1 / (1 + e), e / (1 + e))
    # ``[()]`` unwraps a 0-d result to a scalar and returns arrays unchanged.
    return s[()]
# Hyperbolic tangent; this file uses it as the hidden-layer activation.
def tanh(x):
    """Return ``np.tanh`` applied element-wise to ``x``."""
    result = np.tanh(x)
    return result
def relu(x):
    """Return the element-wise rectified linear unit ``max(x, 0)``.

    Uses ``np.maximum`` instead of ``x * (x > 0)``: the multiply form turns
    ``-inf`` into ``nan`` (``-inf * 0``). NaN inputs still propagate as NaN.
    """
    return np.maximum(x, 0)
def validate(Y, Y1, m):
    """Return the percentage of the first ``m`` columns where ``Y1`` equals ``Y``.

    A sample counts as correct only if every row agrees, so this is
    exact-match accuracy across all class indicators. Any number of rows is
    supported; the previous version assumed exactly 10.

    Args:
        Y: Reference label matrix, one sample per column.
        Y1: Predicted label matrix with the same layout.
        m: Number of leading columns to score.

    Returns:
        Accuracy in percent (0..100).

    Raises:
        ZeroDivisionError: if ``m`` is 0.
    """
    # True for each column whose every row matches.
    matches = np.all(Y[:, :m] == Y1[:, :m], axis=0)
    succ = int(np.count_nonzero(matches))
    return succ / m * 100
#Back prop to get the weight and bias computed using gradient descent
def back_prop(m, w1, w2, b1, b2, X, Y, iterations, iterations_capture_freq, learning_rate):
    """Train the two-layer network with full-batch gradient descent.

    Args:
        m: Number of training examples (columns of ``X`` and ``Y``).
        w1, b1: Hidden-layer weights and bias.
        w2, b2: Output-layer weights and bias.
        X: Inputs, one example per column.
        Y: Target label matrix, one example per column.
        iterations: Number of gradient-descent steps.
        iterations_capture_freq: Record and print the cost every this many steps.
        learning_rate: Gradient-descent step size.

    Returns:
        ``(w1, b1, w2, b2, train_cost)`` -- note the order differs from the
        parameter order. ``train_cost[k]`` is the cost at iteration
        ``k * iterations_capture_freq``.
    """
    # One slot per captured iteration (0, f, 2f, ...). Flooring
    # iterations / f undercounts when iterations is not a multiple of f.
    train_cost = np.zeros(len(range(0, iterations, iterations_capture_freq)))
    # Keeps log() finite when the sigmoid output saturates at exactly 0 or 1.
    eps = 1e-12
    for i in range(iterations):
        _, A1, A2 = forward_prop(m, w1, w2, b1, b2, X)
        # Gradient of binary cross-entropy w.r.t. the sigmoid pre-activation.
        dz2 = A2 - Y
        if i % iterations_capture_freq == 0:
            A2_safe = np.clip(A2, eps, 1 - eps)
            logprobs = np.multiply(np.log(A2_safe), Y) + np.multiply((1 - Y), np.log(1 - A2_safe))
            cost = - np.sum(logprobs) / m
            print("cost : {} - {}".format(i, cost))
            train_cost[i // iterations_capture_freq] = cost
        dw2 = (1 / m) * np.dot(dz2, A1.T)
        db2 = (1 / m) * np.sum(dz2, axis=1, keepdims=True)
        # (1 - A1**2) is the derivative of tanh, the hidden activation.
        dz1 = np.dot(w2.T, dz2) * (1 - np.power(A1, 2))
        dw1 = (1 / m) * np.dot(dz1, X.T)
        db1 = (1 / m) * np.sum(dz1, axis=1, keepdims=True)
        w1 = w1 - learning_rate * dw1
        b1 = b1 - learning_rate * db1
        w2 = w2 - learning_rate * dw2
        b2 = b2 - learning_rate * db2
    return w1, b1, w2, b2, train_cost
#Forward prop to get the predictions
def forward_prop(m, w1, w2, b1, b2, X):
    """Run the network forward and threshold the output activations at 0.5.

    Args:
        m: Number of examples (leading columns of ``X``) to threshold.
        w1, b1: Hidden-layer weights and bias.
        w2, b2: Output-layer weights and bias.
        X: Inputs, one example per column.

    Returns:
        ``(Y, A1, A2)``: ``Y`` is a float 0/1 matrix of shape
        ``(w2.shape[0], m)`` with ``Y[j, i] == 1`` iff ``A2[j, i] > 0.5``.
        ``A1`` holds the tanh hidden activations and ``A2`` the sigmoid
        output activations.
    """
    z1 = np.dot(w1, X) + b1
    A1 = tanh(z1)
    z2 = np.dot(w2, A1) + b2
    A2 = sigmoid(z2)
    # Vectorised threshold replaces the per-element Python double loop.
    # The row count now follows w2 instead of being fixed at 10.
    Y = (A2[:, :m] > 0.5).astype(np.float64)
    return Y, A1, A2
def model(X_train, Y_train, w1, w2, b1, b2, m, iterations, iterations_capture_freq, learning_rate):
    """Train the network, plot its cost history and report training accuracy.

    Args:
        X_train, Y_train: Training inputs and labels, one example per column.
        w1, w2, b1, b2: Initial weights and biases.
        m: Number of training examples.
        iterations, iterations_capture_freq, learning_rate: Passed straight
            through to ``back_prop``.

    Returns:
        Training accuracy in percent as computed by ``validate``.
    """
    # Training phase: gradient descent, then show the captured cost curve.
    trained = back_prop(m, w1, w2, b1, b2, X_train, Y_train,
                        iterations, iterations_capture_freq, learning_rate)
    w1, b1, w2, b2, cost_history = trained
    plot_cost_graph(cost_history)
    # Score the fit with a fresh forward pass using the learned parameters.
    predictions = forward_prop(m, w1, w2, b1, b2, X_train)[0]
    accuracy = validate(Y_train, predictions, m)
    print(Y_train[:, 1])
    print(predictions[:, 1])
    print("Training accuracy : {}%".format(accuracy))
    return accuracy
neurons = 100
m = 100
# Scale initial weights by 1/sqrt(fan_in) to keep the initial pre-activations
# small. Unscaled randn over 3072 inputs starts tanh/sigmoid saturated, where
# their gradients vanish.
w1 = np.random.randn(neurons, 3072) * np.sqrt(1.0 / 3072)
w2 = np.random.randn(10, neurons) * np.sqrt(1.0 / neurons)
b1 = np.zeros((neurons, 1))
# One output bias per class. A (1, 1) bias is shared by all 10 outputs and only
# widens to (10, 1) implicitly via broadcasting during the first update.
b2 = np.zeros((10, 1))
X, Y, names = get_data('data_batch_1')
# Keep the first m examples and rescale uint8 values (0..255) to [0, 1].
# Raw pixel magnitudes push the tanh hidden layer deep into saturation.
X = X[:, 0:m] / 255.0
Y = Y[:, 0:m]
# visualize_image(X, Y, names, 0)  # index must be < m after slicing
model(X, Y, w1, w2, b1, b2, m, 3000, 50, 0.8)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment