Created
September 3, 2020 15:18
-
-
Save pranjalAI/3d6304372acf185fd657d85c59acd23c to your computer and use it in GitHub Desktop.
This code returns the detected item as its output.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
# Let TensorFlow allocate GPU memory on demand instead of reserving it all
# up front. The setting is applied to every listed GPU because TensorFlow
# requires memory growth to be configured identically across GPUs.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)
from absl import app, flags, logging | |
from absl.flags import FLAGS | |
import core.utils as utils | |
from core.yolov4 import filter_boxes | |
from tensorflow.python.saved_model import tag_constants | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
from shutil import copyfile | |
import shutil | |
import os, glob | |
from tensorflow.compat.v1 import ConfigProto | |
from tensorflow.compat.v1 import InteractiveSession | |
# Inference backend used to load and run the model: "tf", "tflite" or "trt".
framework = "tf"
# Model family: "yolov3" or "yolov4".
model = "yolov4"
# True for the yolo-tiny variant, False for full yolo.
tiny = True
# IoU threshold handed to non-max suppression.
iou = 0.45
# Score threshold handed to non-max suppression.
score = 0.25
# Folder the annotated detection image is written to.
output = './detections/'
#def main(): | |
def glass_detector(image_name):
    """Run the YOLO model on one image, save the annotated result and return it.

    The image is read from disk, resized to 416x416, passed through the model
    selected by the module-level ``framework`` setting, filtered with
    non-max suppression (``iou`` / ``score`` thresholds) and handed to
    ``utils.draw_bbox``. The resulting image is written to
    ``<output>/DetectedGlass1.jpg``.

    Args:
        image_name: Path of the image file to analyse.

    Returns:
        The image produced by ``utils.draw_bbox`` as a ``uint8`` array with
        its first and third channels swapped back to OpenCV's BGR order
        (the same array that is written to disk).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_name``.
    """
    input_size = 416

    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    try:
        # NOTE(review): the values returned by load_config (strides, anchors,
        # class count, xy-scale) are not used in this function. The call is
        # kept in case it validates FLAGS -- confirm and remove if not needed.
        utils.load_config(FLAGS)

        # Load the model.
        weights_loaded = "./checkpoints/yolo-tiny-416"
        if framework == 'tflite':
            interpreter = tf.lite.Interpreter(model_path=weights_loaded)
        else:
            saved_model_loaded = tf.saved_model.load(
                weights_loaded, tags=[tag_constants.SERVING])

        # cv2.imread signals failure by returning None rather than raising.
        original_image = cv2.imread(image_name)
        if original_image is None:
            raise FileNotFoundError(
                "Could not read image: {!r}".format(image_name))
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

        # Scale pixels to [0, 1] and add a leading batch axis of size one.
        image_data = cv2.resize(original_image, (input_size, input_size)) / 255.
        images_data = image_data[np.newaxis, ...].astype(np.float32)

        if framework == 'tflite':
            interpreter.allocate_tensors()
            input_details = interpreter.get_input_details()
            output_details = interpreter.get_output_details()
            print(input_details)
            print(output_details)
            interpreter.set_tensor(input_details[0]['index'], images_data)
            interpreter.invoke()
            pred = [interpreter.get_tensor(detail['index'])
                    for detail in output_details]
            # The two model outputs are passed in swapped order for yolov3-tiny.
            if model == 'yolov3' and tiny:
                first, second = pred[1], pred[0]
            else:
                first, second = pred[0], pred[1]
            boxes, pred_conf = filter_boxes(
                first, second, score_threshold=score,
                input_shape=tf.constant([input_size, input_size]))
        else:
            infer = saved_model_loaded.signatures['serving_default']
            pred_bbox = infer(tf.constant(images_data))
            # The last output tensor wins: first 4 columns are box
            # coordinates, the remaining columns are per-class confidences.
            for value in pred_bbox.values():
                boxes = value[:, :, 0:4]
                pred_conf = value[:, :, 4:]

        boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
            scores=tf.reshape(
                pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
            max_output_size_per_class=50,
            max_total_size=50,
            iou_threshold=iou,
            score_threshold=score
        )
        pred_bbox = [boxes.numpy(), scores.numpy(), classes.numpy(),
                     valid_detections.numpy()]

        annotated = utils.draw_bbox(original_image, pred_bbox)
        # The working image is RGB; cv2.imwrite expects BGR.
        image = cv2.cvtColor(annotated.astype(np.uint8), cv2.COLOR_RGB2BGR)

        # cv2.imwrite fails silently when the target folder is missing.
        os.makedirs(output, exist_ok=True)
        cv2.imwrite(os.path.join(output, 'DetectedGlass1.jpg'), image)
        return image
    finally:
        # Release the session so repeated calls do not accumulate sessions.
        session.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment