Skip to content

Instantly share code, notes, and snippets.

@netsatsawat
Created May 28, 2019 09:41
Show Gist options
  • Save netsatsawat/3574c41a8640d0d04576478840495ca5 to your computer and use it in GitHub Desktop.
Save netsatsawat/3574c41a8640d0d04576478840495ca5 to your computer and use it in GitHub Desktop.
First part of a CNN tutorial on the KMNIST dataset.
import numpy as np
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, Dropout
from tensorflow.python.keras.layers import MaxPooling2D, BatchNormalization
from keras.utils.vis_utils import model_to_dot
from keras.utils import plot_model
from tensorflow.python.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score
from sklearn.metrics import confusion_matrix
from tensorflow.keras.callbacks import Callback
# Image geometry and number of label classes, derived from the training data.
# kmnist_train_image / kmnist_train_label are not defined in this snippet and
# must already be loaded; the indexing below assumes the images are stored as
# (n_images, rows, cols) — TODO confirm against the loading code.
# Requires numpy imported as `np` (previously missing from the imports).
IMG_ROWS = kmnist_train_image.shape[1]
IMG_COLS = kmnist_train_image.shape[2]
# Number of distinct training labels. This is only a valid one-hot width for
# keras.utils.to_categorical if labels are integers 0..K-1 and every class
# appears in the training set.
NUM_CLASS = len(np.unique(kmnist_train_label))
def preprocessing_kmnist(img_array, lab_array, *, img_rows=None, img_cols=None,
                         num_classes=None):
    """Reshape and scale images and one-hot encode labels for the CNN.

    Args:
        img_array: numpy array of images whose first axis indexes images; each
            image must hold exactly ``img_rows * img_cols`` pixels.
        lab_array: numpy array of integer class labels, one per image.
        img_rows: image height. Defaults to the module-level ``IMG_ROWS``.
        img_cols: image width. Defaults to the module-level ``IMG_COLS``.
        num_classes: width of the one-hot encoding. Defaults to the
            module-level ``NUM_CLASS``.

    Returns:
        A tuple ``(x, y)``. ``x`` has shape
        ``(n_images, img_rows, img_cols, 1)`` with every pixel divided by 255.
        ``y`` is the one-hot label array from ``keras.utils.to_categorical``.
    """
    # Fall back to the module-level constants so existing callers that pass
    # only (img_array, lab_array) behave exactly as before.
    if img_rows is None:
        img_rows = IMG_ROWS
    if img_cols is None:
        img_cols = IMG_COLS
    if num_classes is None:
        num_classes = NUM_CLASS

    y_ = keras.utils.to_categorical(lab_array, num_classes)
    num_img = img_array.shape[0]
    # Add a single trailing channel axis: the images are treated as grayscale.
    x_ = img_array.reshape(num_img, img_rows, img_cols, 1)
    # Dividing by 255 presumably maps 8-bit pixel values into [0, 1].
    # TODO confirm the input dtype/range.
    x_ = x_ / 255.
    return x_, y_
# Convert the raw KMNIST train/test arrays into model-ready inputs and
# one-hot labels (the kmnist_* arrays are loaded outside this snippet).
X, y = preprocessing_kmnist(kmnist_train_image, kmnist_train_label)
X_test, y_test = preprocessing_kmnist(kmnist_test_image, kmnist_test_label)

# Hold out 20% of the preprocessed training data as a validation set.
# SEED is not defined in this snippet; it fixes the shuffle for reproducibility.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=SEED,
)

# Report the resulting array shapes.
print('Training data shape: ', X_train.shape)
print('Validation data shape: ', X_val.shape)
print('Testing data shape: ', X_test.shape)
# Expected shapes for the full KMNIST data (60,000 train / 10,000 test):
#   Training:   (48000, 28, 28, 1)
#   Validation: (12000, 28, 28, 1)
#   Testing:    (10000, 28, 28, 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment