Skip to content

Instantly share code, notes, and snippets.

@ShawnHymel
Last active February 28, 2024 21:58
Show Gist options
  • Save ShawnHymel/e0b86aa4dbed9fb424f9d0b59ef383c4 to your computer and use it in GitHub Desktop.
Edge Impulse - Classify all images in a directory
#!/usr/bin/env python
import device_patches # Device specific patches for Jetson Nano (needs to be before importing cv2)
import cv2
import os
import sys, getopt
import numpy as np
from edge_impulse_linux.image import ImageImpulseRunner
# Module-level placeholder for the impulse runner; NOTE(review): main() binds
# its own local `runner` in the `with` statement, so this global appears to
# exist only as a fallback for cleanup paths — confirm before removing.
runner = None
def help():
    """Print the command-line usage string for this script."""
    usage = 'python classify-image.py <path_to_model.eim> <path-to-directory/>'
    print(usage)
def main(argv):
    """Classify every image in a directory with an Edge Impulse model.

    Args:
        argv: Command-line arguments (sys.argv[1:]). Expects exactly two
            positional arguments: the .eim model path (relative to this
            script's directory) and the directory of images to classify.

    Exits with status 2 on bad arguments, 0 after --help.
    """
    try:
        # BUGFIX: long option names given to getopt must NOT include the
        # leading "--" (the original passed ["--help"], so `--help` on the
        # command line raised GetoptError instead of matching).
        opts, args = getopt.getopt(argv, "h", ["help"])
    except getopt.GetoptError:
        help()
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('-h', '--help'):
            help()
            sys.exit()

    # Require exactly: <model.eim> <image-directory>
    if len(args) != 2:
        help()
        sys.exit(2)

    model = args[0]

    # Resolve the model path relative to this script's directory, not the CWD.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    modelfile = os.path.join(dir_path, model)
    print('MODEL: ' + modelfile)

    with ImageImpulseRunner(modelfile) as runner:
        try:
            model_info = runner.init()
            print('Loaded runner for "' + model_info['project']['owner'] + ' / ' + model_info['project']['name'] + '"')
            labels = model_info['model_parameters']['labels']

            # Create a list of all files in the directory (ignore subdirectories)
            dirpath = args[1]
            filepaths = [os.path.join(dirpath, f) for f in os.listdir(dirpath) if os.path.isfile(os.path.join(dirpath, f))]

            # Perform inference on all files
            for filepath in filepaths:

                # Open image (skip if not an image)
                print(f"Performing inference on {filepath}")
                img = cv2.imread(filepath)
                if img is None:
                    print(f"Failed to load image: {filepath}")
                    continue

                # imread returns images in BGR format, so we need to convert to RGB
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

                # get_features_from_image also takes a crop direction arguments in case you don't have square images
                features, cropped = runner.get_features_from_image(img)

                # Do inference
                res = runner.classify(features)

                # Print results: a classification model reports per-label
                # scores; an object-detection model reports bounding boxes.
                if "classification" in res["result"].keys():
                    print('Result (%d ms.) ' % (res['timing']['dsp'] + res['timing']['classification']), end='')
                    for label in labels:
                        score = res['result']['classification'][label]
                        print('%s: %.2f\t' % (label, score), end='')
                    print('', flush=True)
                elif "bounding_boxes" in res["result"].keys():
                    print('Found %d bounding boxes (%d ms.)' % (len(res["result"]["bounding_boxes"]), res['timing']['dsp'] + res['timing']['classification']))
                    for bb in res["result"]["bounding_boxes"]:
                        print('\t%s (%.2f): x=%d y=%d w=%d h=%d' % (bb['label'], bb['value'], bb['x'], bb['y'], bb['width'], bb['height']))
                        # Draw the detection onto the cropped frame for optional debug output
                        cropped = cv2.rectangle(cropped, (bb['x'], bb['y']), (bb['x'] + bb['width'], bb['y'] + bb['height']), (255, 0, 0), 1)

                # the image will be resized and cropped, save a copy of the picture here
                # so you can see what's being passed into the classifier
                #cv2.imwrite('debug.jpg', cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR))
        finally:
            # Always release the runner's resources, even on error.
            if (runner):
                runner.stop()
# Script entry point: forward the CLI arguments (minus the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment