Created
November 15, 2023 20:41
-
-
Save tkarna/113d890083eaa6c9d08e8ab3c12054bb to your computer and use it in GitHub Desktop.
Face detect CNN training example notebook
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "7af23f2f-4c0d-4dc4-9442-dcadd68ae38d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\"\"\"\n", | |
"Example adapted from https://github.com/aamini/introtodeeplearning lab2 part2\n", | |
"© MIT Introduction to Deep Learning\n", | |
"http://introtodeeplearning.com\n", | |
"\"\"\"\n", | |
"import tensorflow as tf\n", | |
"import matplotlib.pyplot as plt\n", | |
"import functools\n", | |
"import numpy as np\n", | |
"import h5py" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "dca63189-3a1c-4b5a-818c-ca78e6640d5b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# load training data\n", | |
"path_to_training_data = tf.keras.utils.get_file(\n", | |
" 'train_face.h5', 'https://www.dropbox.com/s/hlz8atheyozp1yx/train_face.h5?dl=1')\n", | |
"cache = h5py.File(path_to_training_data, 'r')\n", | |
"images = cache['images'][:]\n", | |
"labels = cache['labels'][:].astype(np.float32)\n", | |
"n_train_samples = images.shape[0]\n", | |
"np.random.seed(4)\n", | |
"train_inds = np.random.permutation(np.arange(n_train_samples))\n", | |
"pos_train_inds = train_inds[labels[train_inds, 0] == 1.0]\n", | |
"neg_train_inds = train_inds[labels[train_inds, 0] != 1.0]\n", | |
"\n", | |
"\n", | |
"def get_batch(n):\n", | |
" selected_pos_inds = np.random.choice(\n", | |
" pos_train_inds, size=n//2, replace=False, p=None)\n", | |
" selected_neg_inds = np.random.choice(\n", | |
" neg_train_inds, size=n//2, replace=False, p=None)\n", | |
" selected_inds = np.concatenate((selected_pos_inds, selected_neg_inds))\n", | |
"\n", | |
" sorted_inds = np.sort(selected_inds)\n", | |
" train_img = (images[sorted_inds, :, :, ::-1]/255.).astype(np.float32)\n", | |
" train_label = labels[sorted_inds, ...]\n", | |
" return train_img, train_label" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "a83c029e-b42f-43fa-b38e-7b71cc3f1755", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# define model\n", | |
"n_filters = 12\n", | |
"\n", | |
"\n", | |
"def make_standard_classifier(n_outputs=1):\n", | |
" Conv2D = functools.partial(\n", | |
" tf.keras.layers.Conv2D, padding='same', activation='relu')\n", | |
" BatchNormalization = tf.keras.layers.BatchNormalization\n", | |
" Flatten = tf.keras.layers.Flatten\n", | |
" Dense = functools.partial(tf.keras.layers.Dense, activation='relu')\n", | |
"\n", | |
" model = tf.keras.Sequential([\n", | |
" Conv2D(filters=1*n_filters, kernel_size=5, strides=2),\n", | |
" BatchNormalization(),\n", | |
"\n", | |
" Conv2D(filters=2*n_filters, kernel_size=5, strides=2),\n", | |
" BatchNormalization(),\n", | |
"\n", | |
" Conv2D(filters=4*n_filters, kernel_size=3, strides=2),\n", | |
" BatchNormalization(),\n", | |
"\n", | |
" Conv2D(filters=6*n_filters, kernel_size=3, strides=2),\n", | |
" BatchNormalization(),\n", | |
"\n", | |
" Flatten(),\n", | |
" Dense(512),\n", | |
" Dense(n_outputs, activation=None),\n", | |
" ])\n", | |
" return model\n", | |
"\n", | |
"\n", | |
"model = make_standard_classifier()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "cb09d82d-010d-4a5b-8aaf-77eaf6ca10d0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"batch_size = 32\n", | |
"num_epochs = 2\n", | |
"learning_rate = 5e-4\n", | |
"optimizer = tf.keras.optimizers.Adam(learning_rate)\n", | |
"loss_history = []\n", | |
"smooth = 0.99\n", | |
"\n", | |
"\n", | |
"@tf.function\n", | |
"def train_step(x, y):\n", | |
" with tf.GradientTape() as tape:\n", | |
" logits = model(x)\n", | |
" loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n", | |
"\n", | |
" grads = tape.gradient(loss, model.trainable_variables)\n", | |
" optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", | |
" return loss\n", | |
"\n", | |
"\n", | |
"for epoch in range(num_epochs):\n", | |
" print(f\"Epoch {epoch}/{num_epochs}\")\n", | |
" for idx in range(n_train_samples//batch_size):\n", | |
" x, y = get_batch(batch_size)\n", | |
" loss = train_step(x, y)\n", | |
"\n", | |
" loss_scalar = loss.numpy().mean()\n", | |
" loss_history.append(\n", | |
" smooth*loss_history[-1] + (1-smooth)*loss_scalar if len(loss_history) > 0 else loss_scalar)\n", | |
" if idx % 400 == 0:\n", | |
" print(f\" batch {idx:4d} loss={loss_history[-1]:.5f}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5ffda736-18a8-4e47-9ea3-bde9132c6419", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"plt.semilogy(loss_history)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "95482d7c-2c71-47ae-872d-f8ae0c2b89c6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"(batch_x, batch_y) = get_batch(5000)\n", | |
"y_pred = tf.round(tf.nn.sigmoid(model.predict(batch_x)))\n", | |
"acc = tf.reduce_mean(tf.cast(tf.equal(batch_y, y_pred), tf.float32))\n", | |
"print(\"CNN accuracy on training set: {:.4f}\".format(acc.numpy()))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment