Created
October 4, 2021 05:19
-
-
Save Sycarol/3edefbd84f16b5978a137819c585a635 to your computer and use it in GitHub Desktop.
Jupyter notebook that runs inference with a Keras classification model on a folder of images and displays the classification results alongside the images for interpretation.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"source": [ | |
"import tensorflow as tf\n", | |
"import numpy as np \n", | |
"from tensorflow.keras.preprocessing import image\n", | |
"import os\n", | |
"import matplotlib.pyplot as plt" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Configurations" | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"source": [ | |
"DATA_FOLDER = \"data\"\n", | |
"# folder containing the classes\n", | |
"data_dir = \"./data/{}\".format(DATA_FOLDER)\n", | |
"MODEL_NAME=\"model_name\"\n", | |
"model_dir=\"./exported_models/{}\".format(MODEL_NAME)\n", | |
"\n", | |
"class_indices={0:\"class1\", 1:\"class2\",2:\"class3\"}\n", | |
"inv_classes = {v: k for k, v in class_indices.items()}\n", | |
"num_classes=len(class_indices)\n", | |
"target_size=(128,128)" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Function definitions" | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"source": [ | |
"def plot_image(i, predictions_array, true_label, img):\n", | |
" true_label, img = true_label[i], img[i].astype('uint8')\n", | |
" plt.grid(False)\n", | |
" plt.xticks([])\n", | |
" plt.yticks([])\n", | |
"\n", | |
" plt.imshow(img, cmap=plt.cm.binary)\n", | |
"\n", | |
" predicted_label = np.argmax(predictions_array)\n", | |
" if predicted_label == true_label:\n", | |
" color = 'blue'\n", | |
" else:\n", | |
" color = 'red'\n", | |
"\n", | |
" plt.xlabel(\"{} {:2.0f}% ({})\".format(class_indices[predicted_label],\n", | |
" 100*np.max(predictions_array),\n", | |
" class_indices[true_label]),\n", | |
" color=color)\n", | |
"\n", | |
"def plot_value_array(i, predictions_array, true_label):\n", | |
" true_label = true_label[i]\n", | |
" plt.grid(False)\n", | |
" plt.xticks(range(num_classes))\n", | |
" plt.yticks([])\n", | |
" thisplot = plt.bar(range(num_classes), predictions_array, color=\"#777777\")\n", | |
" plt.ylim([0, 1])\n", | |
" predicted_label = np.argmax(predictions_array)\n", | |
"\n", | |
" thisplot[predicted_label].set_color('red')\n", | |
" thisplot[true_label].set_color('blue')\n" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Run" | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"source": [ | |
"# load model\n", | |
"model = tf.keras.models.load_model(model_dir)\n", | |
"\n", | |
"#loop for all classes\n", | |
"for (root,dirs,files) in os.walk(data_dir):\n", | |
" if len(files) >0:\n", | |
" # load images\n", | |
" img_list=[]\n", | |
" test_labels=[]\n", | |
" test_images=[]\n", | |
" for f in files:\n", | |
" if f.endswith(\".jpg\") or f.endswith('.png'):\n", | |
" img = image.load_img(os.path.join(root,f), target_size=target_size)\n", | |
" x = image.img_to_array(img)\n", | |
" test_images.append(x)\n", | |
" x = np.expand_dims(x, axis=0)\n", | |
" img_list.append(x)\n", | |
" test_labels.append(inv_classes[os.path.basename(root)])\n", | |
" images=np.vstack(img_list)\n", | |
" # predict\n", | |
" predictions = model.predict(images)\n", | |
"\n", | |
" # show predictions\n", | |
" num_rows = 5\n", | |
" num_cols = 5\n", | |
" num_images = num_rows*num_cols\n", | |
" plt.figure(figsize=(2*2*num_cols, 2*num_rows))\n", | |
" for i in range(num_images):\n", | |
" plt.subplot(num_rows, 2*num_cols, 2*i+1)\n", | |
" plot_image(i, predictions[i], test_labels, test_images)\n", | |
" plt.subplot(num_rows, 2*num_cols, 2*i+2)\n", | |
" plot_value_array(i, predictions[i], test_labels)\n", | |
" plt.tight_layout()\n", | |
" plt.show()" | |
], | |
"outputs": [], | |
"metadata": {} | |
} | |
], | |
"metadata": { | |
"orig_nbformat": 4, | |
"language_info": { | |
"name": "python", | |
"version": "3.6.9", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3.6.9 64-bit ('keras': venv)" | |
}, | |
"interpreter": { | |
"hash": "caa3c448db030c47d443f1bf3e4d2b22169e43afde0cc3d48598503fa91d349c" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment