Created
March 31, 2019 11:06
-
-
Save sdoshi579/a81d5b3c88afcda829b74e0e48a86c2d to your computer and use it in GitHub Desktop.
Using Keras to implement a CNN model that recognizes hand-written digits
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%pylab inline
import os

import numpy as np
import pandas as pd
from scipy.misc import imread  # NOTE: removed in scipy>=1.2 — pin scipy<1.2 or port to imageio.imread
from sklearn.metrics import accuracy_score
import tensorflow as tf
import keras
from keras.models import Sequential  # was missing: Sequential is instantiated below
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import BatchNormalization
from keras.optimizers import Adam
# Locate the dataset root and read the three CSV index files:
# the labeled training index, the unlabeled test index, and the
# submission template we will overwrite at the end.
data_dir = os.path.abspath('path/to/directory/of/dataset')

train_csv_path = os.path.join(data_dir, 'Train', 'train.csv')
test_csv_path = os.path.join(data_dir, 'Test.csv')
submission_csv_path = os.path.join(data_dir, 'Sample_Submission.csv')

train = pd.read_csv(train_csv_path)
test = pd.read_csv(test_csv_path)
sample_submission = pd.read_csv(submission_csv_path)
# The dataset ships as individual image files; turn them into one array.
# Grayscale-load (flatten=True) every training image listed in train.csv,
# in CSV order so rows stay aligned with train.label.
train_img_dir = os.path.join(data_dir, 'Train', 'Images', 'train')
train_x = np.stack([
    imread(os.path.join(train_img_dir, img_name), flatten=True).astype('float32')
    for img_name in train.filename
])
print(train_x.shape)

# Scale pixel intensities to [0, 1] and flatten each 28x28 image to 784 values.
# (The model later reshapes these back to (28, 28, 1) for the conv layers.)
train_x /= 255.0
train_x = train_x.reshape(-1, 784)
print(train_x.shape)
# Same preprocessing for the test split.
# NOTE(review): test images live under Train/Images/test in this dataset's
# layout — presumably intentional; verify against the actual download.
test_img_dir = os.path.join(data_dir, 'Train', 'Images', 'test')
test_x = np.stack([
    imread(os.path.join(test_img_dir, img_name), flatten=True).astype('float32')
    for img_name in test.filename
])
test_x /= 255.0
test_x = test_x.reshape(-1, 784)
# One-hot encode the digit labels for categorical cross-entropy.
train_y = keras.utils.np_utils.to_categorical(train.label.values)

# Hold out the last 30% of the (unshuffled) training rows for validation.
split_size = int(train_x.shape[0] * 0.7)

val_x = train_x[split_size:]
train_x = train_x[:split_size]
val_y = train_y[split_size:]
train_y = train_y[:split_size]

# Keras Conv2D expects NHWC input: (samples, height, width, channels).
train_x_temp = train_x.reshape(-1, 28, 28, 1)
val_x_temp = val_x.reshape(-1, 28, 28, 1)
# --- Model hyperparameters ---
input_shape = (784,)           # flat per-image length (kept for reference; conv model uses input_reshape)
input_reshape = (28, 28, 1)    # NHWC shape of one grayscale digit image
pool_size = (2, 2)             # FIX: referenced by MaxPooling2D below but was never defined (NameError)
hidden_num_units = 2048        # dense head layer sizes: 2048 -> 1024 -> 128
hidden_num_units1 = 1024
hidden_num_units2 = 128
output_num_units = 10          # one class per digit 0-9
epochs = 20
batch_size = 16
# Three conv stages (16 -> 32 -> 64 filters), each Conv-BN-Conv-BN-Pool-Dropout,
# followed by a 2048-1024-128 dense head and a 10-way softmax classifier.
# Requires `from keras.models import Sequential` at the top of the file.
model = Sequential([
    Conv2D(16, (3, 3), activation='relu', input_shape=input_reshape, padding='same'),
    BatchNormalization(),
    Conv2D(16, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),  # FIX: was pool_size=pool_size, an undefined name
    Dropout(0.2),

    Conv2D(32, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.2),

    Conv2D(64, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.2),

    Flatten(),
    Dense(units=hidden_num_units, activation='relu'),
    Dropout(0.3),
    Dense(units=hidden_num_units1, activation='relu'),
    Dropout(0.3),
    Dense(units=hidden_num_units2, activation='relu'),
    Dropout(0.3),
    # FIX: dropped input_dim=hidden_num_units — input_dim is ignored on
    # non-first layers and contradicted the 128-unit layer directly above.
    Dense(units=output_num_units, activation='softmax'),
])

model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=1e-4), metrics=['accuracy'])
# Train on the 70% split, monitoring loss/accuracy on the held-out 30%.
trained_model_conv = model.fit(train_x_temp, train_y, epochs=epochs, batch_size=batch_size,
                               validation_data=(val_x_temp, val_y))

# Predict class indices for the test set. np.argmax over the softmax output is
# equivalent to Sequential.predict_classes, which newer Keras/TF versions removed.
test_x_temp = test_x.reshape(-1, 28, 28, 1)
pred = np.argmax(model.predict(test_x_temp), axis=1)

# Fill the submission template and write it next to the dataset.
sample_submission.filename = test.filename
sample_submission.label = pred
sample_submission.to_csv(os.path.join(data_dir, 'sub13.csv'), index=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
omar