# -*- coding: utf-8 -*-
"""
Created on Mon Dec 04 17:59:48 2017
@author: Tathagat Dasgupta
"""
import os
import glob
import random
import time

import cv2
import numpy as np
import theano
import theano.tensor as T
import lasagne
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
start_time = time.time()

emotions = ["neutral", "anger", "contempt", "disgust", "fear", "happy", "sadness", "surprise"]  # label i corresponds to emotions[i]
def get_files(emotion):  # Build the file list for one emotion, shuffle it, and split 80/20
    files = glob.glob(os.path.join("dataset", emotion, "*"))
    random.shuffle(files)
    training = files[:int(len(files) * 0.8)]     # first 80% of the shuffled list
    prediction = files[-int(len(files) * 0.2):]  # last 20% of the shuffled list
    return training, prediction
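# Note: because both slice bounds are truncated with int(), the split can leave
# a file or two unused when the count is not a multiple of five (e.g. 9 files
# -> 7 training + 1 prediction), but the two sets never overlap.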
def make_sets():
    training_data = []
    training_labels = []
    prediction_data = []
    prediction_labels = []
    for emotion in emotions:
        training, prediction = get_files(emotion)
        # Append data to the training and prediction lists, and generate labels 0-7
        for item in training:
            image = cv2.imread(item)  # open image
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # convert to grayscale
            training_data.append(gray)  # append image array to training data list
            training_labels.append(emotions.index(emotion))
        for item in prediction:  # repeat the process for the prediction set
            image = cv2.imread(item)
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            prediction_data.append(gray)
            prediction_labels.append(emotions.index(emotion))
    return training_data, training_labels, prediction_data, prediction_labels
# make_sets() returns (data, labels, data, labels) -- unpack in that order
X_train, y_train, X_test, y_test = make_sets()

# Stack into arrays with an explicit channel axis, (N, 1, 350, 350), and the
# dtypes Theano expects (float32 inputs, int32 targets); this assumes every
# image in the dataset is 350x350, matching data_size below.
X_train = np.asarray(X_train, dtype=np.float32)[:, np.newaxis, :, :]
X_test = np.asarray(X_test, dtype=np.float32)[:, np.newaxis, :, :]
y_train = np.asarray(y_train, dtype=np.int32)
y_test = np.asarray(y_test, dtype=np.int32)
print(X_train.shape)
print(y_train.shape)
batch_size = 10

# Conv net structure
output_size = 8  # one class per emotion
data_size = (None, 1, 350, 350)  # (batch, channel, row, col)

input_var = T.tensor4(name='inputs')    # float32 image batches
target_var = T.ivector(name='targets')  # int32 class indices
net = {}
# Input layer:
net['data'] = lasagne.layers.InputLayer(data_size, input_var=input_var)
# Convolution + pooling:
net['conv1'] = lasagne.layers.Conv2DLayer(net['data'], num_filters=6, filter_size=3)
net['pool1'] = lasagne.layers.Pool2DLayer(net['conv1'], pool_size=2)
net['conv2'] = lasagne.layers.Conv2DLayer(net['pool1'], num_filters=10, filter_size=4)
net['pool2'] = lasagne.layers.Pool2DLayer(net['conv2'], pool_size=2)
net['conv3'] = lasagne.layers.Conv2DLayer(net['pool2'], num_filters=20, filter_size=2)
net['conv4'] = lasagne.layers.Conv2DLayer(net['conv3'], num_filters=20, filter_size=2)
net['conv5'] = lasagne.layers.Conv2DLayer(net['conv4'], num_filters=20, filter_size=2)
net['pool3'] = lasagne.layers.Pool2DLayer(net['conv5'], pool_size=2)
# Fully connected:
net['fc1'] = lasagne.layers.DenseLayer(net['pool3'], num_units=100)
net['fc2'] = lasagne.layers.DenseLayer(net['fc1'], num_units=100)
# Output layer:
net['out'] = lasagne.layers.DenseLayer(net['fc2'], num_units=output_size,
                                       nonlinearity=lasagne.nonlinearities.softmax)
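# Rough shape walk-through for a 350x350 input, assuming Lasagne's defaults
# ('valid' convolutions with stride 1; non-overlapping 2x2 max pooling):
#   conv1 (3x3): 348x348 -> pool1: 174x174
#   conv2 (4x4): 171x171 -> pool2:  85x85
#   conv3 (2x2):  84x84
#   conv4 (2x2):  83x83
#   conv5 (2x2):  82x82  -> pool3:  41x41
# so fc1 sees 20 * 41 * 41 = 33620 flattened features.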
### Defining the cost function and the update rule
# Hyperparameters (these could also be symbolic variables)
lr = 1e-2
weight_decay = 1e-5
# Loss function: mean cross-entropy
prediction = lasagne.layers.get_output(net['out'])
loss = lasagne.objectives.categorical_crossentropy(prediction, target_var)
loss = loss.mean()
# Also add L2 weight decay to the cost function
weightsl2 = lasagne.regularization.regularize_network_params(net['out'], lasagne.regularization.l2)
loss += weight_decay * weightsl2
# Get the update rule; Adam is used here, but plain SGD would be:
#   updates = lasagne.updates.sgd(loss, params, learning_rate=lr)
params = lasagne.layers.get_all_params(net['out'], trainable=True)
updates = lasagne.updates.adam(loss, params)
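# Design note: Adam keeps per-parameter adaptive step sizes, so the manual
# learning rate lr above only matters if you switch back to the SGD rule;
# lasagne.updates.adam falls back to its own default learning rate here.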
### Compiling the training and testing functions
train_fn = theano.function([input_var, target_var], loss, updates=updates)

test_prediction = lasagne.layers.get_output(net['out'], deterministic=True)
test_loss = lasagne.objectives.categorical_crossentropy(test_prediction,
                                                        target_var)
test_loss = test_loss.mean()
test_acc = T.mean(T.eq(T.argmax(test_prediction, axis=1), target_var),
                  dtype=theano.config.floatX)
val_fn = theano.function([input_var, target_var], [test_loss, test_acc])
get_preds = theano.function([input_var], test_prediction)
### Training the model
# Run the training function over mini-batches (any leftover examples beyond a
# full batch are skipped each epoch).
n_examples = len(X_train)
n_batches = n_examples // batch_size
epochs = 50
for epoch in range(epochs):
    epoch_loss = 0.0
    for batch in range(n_batches):
        x_batch = X_train[batch * batch_size:(batch + 1) * batch_size]
        y_batch = y_train[batch * batch_size:(batch + 1) * batch_size]
        epoch_loss += train_fn(x_batch, y_batch)  # this is where the model gets updated
    print('Epoch %d: mean training loss %f' % (epoch + 1, epoch_loss / n_batches))
### Testing the model
loss, acc = val_fn(X_test, y_test)
test_error = 1 - acc
print('Test error: %f' % test_error)
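# A per-class breakdown is often more informative than a single error number;
# a minimal sketch using get_preds and the sklearn metrics imported above:
y_pred = np.argmax(get_preds(X_test), axis=1)
print(classification_report(y_test, y_pred, target_names=emotions))
print(confusion_matrix(y_test, y_pred))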
# Computation time
print('%.2f seconds' % (time.time() - start_time))