Skip to content

Instantly share code, notes, and snippets.

@Sycarol
Created October 4, 2021 05:19
Show Gist options
  • Save Sycarol/3edefbd84f16b5978a137819c585a635 to your computer and use it in GitHub Desktop.
Save Sycarol/3edefbd84f16b5978a137819c585a635 to your computer and use it in GitHub Desktop.
Jupyter notebook that runs inference with a Keras classification model on a folder of images and displays each image alongside its predicted class and per-class scores, for interpreting the classification results.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"source": [
"import tensorflow as tf\n",
"import numpy as np \n",
"from tensorflow.keras.preprocessing import image\n",
"import os\n",
"import matplotlib.pyplot as plt"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Configurations"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"DATA_FOLDER = \"data\"\n",
"# folder containing one sub-folder of images per class; sub-folder names must be values of class_indices\n",
"data_dir = os.path.join(\".\", \"data\", DATA_FOLDER)\n",
"MODEL_NAME = \"model_name\"\n",
"model_dir = os.path.join(\".\", \"exported_models\", MODEL_NAME)\n",
"\n",
"# maps the model's output index (argmax of model.predict) to a class name;\n",
"# must agree with the label order the model was trained with\n",
"class_indices = {0: \"class1\", 1: \"class2\", 2: \"class3\"}\n",
"inv_classes = {v: k for k, v in class_indices.items()}\n",
"# a duplicated class name would silently collapse into one entry of inv_classes\n",
"assert len(inv_classes) == len(class_indices), \"class names in class_indices must be unique\"\n",
"num_classes = len(class_indices)\n",
"# (height, width) every image is resized to before prediction; presumably the model's input size\n",
"target_size = (128, 128)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Function definitions"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def plot_image(i, predictions_array, true_label, img):\n",
"    \"\"\"Show image i with its predicted class, confidence and true class as the x-label.\n",
"\n",
"    Args:\n",
"        i: index of the image to draw in true_label and img.\n",
"        predictions_array: per-class scores for image i (one row of model.predict output).\n",
"        true_label: sequence of integer class indices, one per image.\n",
"        img: sequence of image arrays; the selected one is cast to uint8 for display.\n",
"\n",
"    The label text is blue when the argmax prediction equals the true label, red otherwise.\n",
"    \"\"\"\n",
"    true_label, img = true_label[i], img[i].astype('uint8')\n",
"    plt.grid(False)\n",
"    plt.xticks([])\n",
"    plt.yticks([])\n",
"\n",
"    plt.imshow(img, cmap=plt.cm.binary)\n",
"\n",
"    predicted_label = np.argmax(predictions_array)\n",
"    color = 'blue' if predicted_label == true_label else 'red'\n",
"\n",
"    plt.xlabel(\n",
"        \"{} {:2.0f}% ({})\".format(\n",
"            class_indices[predicted_label],\n",
"            100 * np.max(predictions_array),\n",
"            class_indices[true_label],\n",
"        ),\n",
"        color=color,\n",
"    )\n",
"\n",
"\n",
"def plot_value_array(i, predictions_array, true_label):\n",
"    \"\"\"Bar-plot the per-class scores of image i.\n",
"\n",
"    The argmax bar is drawn red and the true-class bar blue (blue wins when they coincide).\n",
"    The y-axis is fixed to [0, 1], so this assumes the model outputs probabilities (e.g. softmax).\n",
"    \"\"\"\n",
"    true_label = true_label[i]\n",
"    # derive the bar count from the scores so the plot always matches the model output width\n",
"    n_classes = len(predictions_array)\n",
"    plt.grid(False)\n",
"    plt.xticks(range(n_classes))\n",
"    plt.yticks([])\n",
"    thisplot = plt.bar(range(n_classes), predictions_array, color=\"#777777\")\n",
"    plt.ylim([0, 1])\n",
"    predicted_label = np.argmax(predictions_array)\n",
"\n",
"    thisplot[predicted_label].set_color('red')\n",
"    thisplot[true_label].set_color('blue')\n"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Run"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# load model\n",
"model = tf.keras.models.load_model(model_dir)\n",
"\n",
"IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')\n",
"num_cols = 5\n",
"max_rows = 5\n",
"\n",
"# loop over the class folders; os.walk order is not guaranteed, so sort for a reproducible figure order\n",
"for root, _, files in sorted(os.walk(data_dir)):\n",
"    class_name = os.path.basename(root)\n",
"    # skip data_dir itself and any folder that is not a configured class (would otherwise raise KeyError)\n",
"    if class_name not in inv_classes:\n",
"        continue\n",
"    # case-insensitive match so .JPG / .jpeg files are not silently ignored\n",
"    image_files = sorted(f for f in files if f.lower().endswith(IMAGE_EXTENSIONS))\n",
"    # a folder with no images would make np.stack fail on an empty list\n",
"    if not image_files:\n",
"        continue\n",
"\n",
"    # load images, resized to target_size\n",
"    test_images = [\n",
"        image.img_to_array(image.load_img(os.path.join(root, f), target_size=target_size))\n",
"        for f in image_files\n",
"    ]\n",
"    test_labels = [inv_classes[class_name]] * len(test_images)\n",
"    images = np.stack(test_images)\n",
"    # predict; no rescaling is applied here, so this assumes the model does its own preprocessing - TODO confirm\n",
"    predictions = model.predict(images)\n",
"\n",
"    # show at most max_rows * num_cols images; folders with fewer images get a shorter grid instead of an IndexError\n",
"    num_images = min(max_rows * num_cols, len(test_images))\n",
"    num_rows = -(-num_images // num_cols)  # ceiling division\n",
"    plt.figure(figsize=(2*2*num_cols, 2*num_rows))\n",
"    for i in range(num_images):\n",
"        plt.subplot(num_rows, 2*num_cols, 2*i+1)\n",
"        plot_image(i, predictions[i], test_labels, test_images)\n",
"        plt.subplot(num_rows, 2*num_cols, 2*i+2)\n",
"        plot_value_array(i, predictions[i], test_labels)\n",
"    plt.tight_layout()\n",
"    plt.show()"
],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"orig_nbformat": 4,
"language_info": {
"name": "python",
"version": "3.6.9",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.6.9 64-bit ('keras': venv)"
},
"interpreter": {
"hash": "caa3c448db030c47d443f1bf3e4d2b22169e43afde0cc3d48598503fa91d349c"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment