Last active
February 28, 2024 21:58
-
-
Save ShawnHymel/e0b86aa4dbed9fb424f9d0b59ef383c4 to your computer and use it in GitHub Desktop.
Edge Impulse - Classify all images in a directory
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import device_patches # Device specific patches for Jetson Nano (needs to be before importing cv2) | |
import cv2 | |
import os | |
import sys, getopt | |
import numpy as np | |
from edge_impulse_linux.image import ImageImpulseRunner | |
# Module-level placeholder. NOTE(review): main() binds its own local `runner`
# via `with ImageImpulseRunner(...) as runner`, so this global is never read there.
runner = None
def help():
    """Print the command-line usage line for this script to stdout."""
    usage_line = 'python classify-image.py <path_to_model.eim> <path-to-directory/>'
    print(usage_line)
def _classify_file(runner, labels, filepath):
    """Run inference on one file with ``runner`` and print the result.

    Args:
        runner: An ImageImpulseRunner on which ``init()`` has already been called.
        labels: Label names from the model's ``model_parameters``; classification
            scores are printed in this order.
        filepath: Path of the file to classify. Files that OpenCV cannot decode
            are reported and skipped.
    """
    print(f"Performing inference on {filepath}")
    img = cv2.imread(filepath)
    if img is None:
        # cv2.imread returns None (rather than raising) for non-image/unreadable files.
        print(f"Failed to load image: {filepath}")
        return

    # imread returns images in BGR order, so convert to RGB before extracting features.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # get_features_from_image also takes a crop direction argument in case
    # you don't have square images.
    features, cropped = runner.get_features_from_image(img)
    res = runner.classify(features)
    result = res["result"]

    if "classification" in result:
        timing_ms = res['timing']['dsp'] + res['timing']['classification']
        print('Result (%d ms.) ' % timing_ms, end='')
        for label in labels:
            score = result['classification'][label]
            print('%s: %.2f\t' % (label, score), end='')
        print('', flush=True)
    elif "bounding_boxes" in result:
        boxes = result["bounding_boxes"]
        timing_ms = res['timing']['dsp'] + res['timing']['classification']
        print('Found %d bounding boxes (%d ms.)' % (len(boxes), timing_ms))
        for bb in boxes:
            print('\t%s (%.2f): x=%d y=%d w=%d h=%d' % (
                bb['label'], bb['value'], bb['x'], bb['y'], bb['width'], bb['height']))
            cropped = cv2.rectangle(
                cropped,
                (bb['x'], bb['y']),
                (bb['x'] + bb['width'], bb['y'] + bb['height']),
                (255, 0, 0),
                1,
            )

    # The image is resized and cropped before inference. To see exactly what is
    # passed into the classifier (with any boxes drawn), uncomment for debugging:
    # cv2.imwrite('debug.jpg', cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR))


def main(argv):
    """Classify every image file in a directory with an Edge Impulse model.

    Usage: python classify-image.py <path_to_model.eim> <path-to-directory/>

    Args:
        argv: Command-line arguments without the program name (``sys.argv[1:]``).

    Exits with status 0 after ``-h``/``--help``, and with status 2 (after
    printing usage or an error) on bad options, a wrong argument count, or
    when the image directory does not exist.
    """
    try:
        # getopt expects long option names WITHOUT the leading "--"; passing
        # "--help" would register "----help" and reject "--help".
        opts, args = getopt.getopt(argv, "h", ["help"])
    except getopt.GetoptError:
        help()
        sys.exit(2)

    for opt, _ in opts:
        if opt in ('-h', '--help'):
            help()
            sys.exit()

    if len(args) != 2:
        help()
        sys.exit(2)

    model, dirpath = args

    # Validate the image directory before spending time loading the model.
    if not os.path.isdir(dirpath):
        print('Not a directory: ' + dirpath)
        sys.exit(2)

    # Relative model paths are resolved against this script's directory, not
    # the current working directory (os.path.join keeps absolute paths as-is).
    dir_path = os.path.dirname(os.path.realpath(__file__))
    modelfile = os.path.join(dir_path, model)
    print('MODEL: ' + modelfile)

    with ImageImpulseRunner(modelfile) as runner:
        try:
            model_info = runner.init()
            print('Loaded runner for "' + model_info['project']['owner'] + ' / '
                  + model_info['project']['name'] + '"')
            labels = model_info['model_parameters']['labels']

            # Regular files only (subdirectories are ignored); sorted because
            # os.listdir returns entries in arbitrary order.
            filepaths = sorted(
                os.path.join(dirpath, name)
                for name in os.listdir(dirpath)
                if os.path.isfile(os.path.join(dirpath, name))
            )

            for filepath in filepaths:
                _classify_file(runner, labels, filepath)
        finally:
            if runner:
                runner.stop()
if __name__ == "__main__":
    # Script entry point: forward everything after the program name to main().
    cli_args = sys.argv[1:]
    main(cli_args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment