Skip to content

Instantly share code, notes, and snippets.

View zakirangwala's full-sized avatar
🐛
squashing bugs

Zaki Rangwala zakirangwala

🐛
squashing bugs
View GitHub Profile
@zakirangwala
zakirangwala / detect.py
Created December 1, 2020 07:17
Real-Time Prediction - Mask Detection
# Real Time Prediction -> Video Capture
def real_time_prediction():
category_index = label_map_util.create_category_index_from_labelmap(
ANNOTATION_PATH+'/label_map.pbtxt')
cap = cv2.VideoCapture(0)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Make detection
while True:
@zakirangwala
zakirangwala / detect.py
Created December 1, 2020 06:57
Check Model - Mask Detection
def check(image):
image_np = cv2.imread(image)
input_tensor = tf.convert_to_tensor(
np.expand_dims(image_np, 0), dtype=tf.float32)
detections = detect_fn(input_tensor)
category_index = label_map_util.create_category_index_from_labelmap(
ANNOTATION_PATH+'/label_map.pbtxt')
num_detections = int(detections.pop('num_detections'))
@zakirangwala
zakirangwala / detect.py
Created December 1, 2020 06:48
Detection Function - Mask Detection
# Cache for the restored detection model; filled on first use by
# _get_detection_model() so the model is built only once per process.
_DETECTION_MODEL = None


def _get_detection_model():
    """Return the detection model, calling load_model() only on first use."""
    global _DETECTION_MODEL
    if _DETECTION_MODEL is None:
        _DETECTION_MODEL = load_model()
    return _DETECTION_MODEL


def detect_fn(image):
    """Run the detection model on an input image tensor.

    The model is built and its checkpoint restored once, then reused on every
    subsequent call. Rebuilding it per call (re-reading the pipeline config,
    constructing the network and restoring the checkpoint) is far too slow
    when this function is called repeatedly, e.g. once per frame.

    Args:
        image: Input tensor passed to the model's ``preprocess``. check()
            passes a float32 tensor built with ``np.expand_dims(img, 0)``,
            i.e. with a leading batch dimension of 1.

    Returns:
        The result of ``detection_model.postprocess``. check() treats it as
        a dict containing a 'num_detections' entry.
    """
    detection_model = _get_detection_model()
    image, shapes = detection_model.preprocess(image)
    prediction_dict = detection_model.predict(image, shapes)
    detections = detection_model.postprocess(prediction_dict, shapes)
    return detections
@zakirangwala
zakirangwala / detect.py
Created December 1, 2020 06:35
Load model - Mask Detection
# Load Model from checkpoints
def load_model(checkpoint_name='ckpt-9'):
    """Build the detection model from the pipeline config and restore weights.

    Args:
        checkpoint_name: Checkpoint prefix, relative to CHECKPOINT_PATH, to
            restore. Defaults to 'ckpt-9', the checkpoint previously
            hard-coded here, so existing callers are unaffected.

    Returns:
        The model produced by ``model_builder.build`` from the 'model'
        section of the config at CONFIG_PATH, built with
        ``is_training=False`` and with the checkpoint restored into it.
    """
    configs = config_util.get_configs_from_pipeline_file(CONFIG_PATH)
    detection_model = model_builder.build(
        model_config=configs['model'], is_training=False)
    ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
    # expect_partial() silences warnings about checkpoint values that this
    # inference-only model does not consume.
    ckpt.restore(os.path.join(CHECKPOINT_PATH, checkpoint_name)).expect_partial()
    return detection_model
@zakirangwala
zakirangwala / detect.py
Created December 1, 2020 06:19
Pipeline Configuration - Mask Detection
# Configuration
def config():
CONFIG_PATH = MODEL_PATH+'/'+CUSTOM_MODEL_NAME+'/pipeline.config'
config = config_util.get_configs_from_pipeline_file(CONFIG_PATH)
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.io.gfile.GFile(CONFIG_PATH, "r") as f:
proto_str = f.read()
text_format.Merge(proto_str, pipeline_config)
# Modify pipeline
pipeline_config.model.ssd.num_classes = 2
@zakirangwala
zakirangwala / xml_to_csv.py
Created December 1, 2020 05:54
Convert file format
# XML TO CSV
def xml_to_csv(path):
xml_list = []
for xml_file in glob.glob(path + '/*.xml'):
tree = ET.parse(xml_file)
root = tree.getroot()
for member in root.findall('object'):
value = (root.find('filename').text+".jpg",
int(root.find('size')[0].text),
int(root.find('size')[1].text),
@zakirangwala
zakirangwala / detect.py
Created December 1, 2020 05:49
Create Label Map - Mask Detector
# Label Map
def construct_label_map(labels=None, annotation_path=None):
    """Write the label map file ``label_map.pbtxt`` for the detector.

    Each label becomes one ``item { name:'...' id:N }`` entry. The file is
    written to ``<annotation_path>/label_map.pbtxt``, the same location the
    detection code reads it from. Any existing file is overwritten.

    Args:
        labels: Sequence of dicts with 'name' and 'id' keys. Defaults to the
            project's two classes: 'Mask' (id 1) and 'NoMask' (id 2).
        annotation_path: Directory to write into. Defaults to ANNOTATION_PATH.
    """
    if labels is None:
        labels = [{'name': 'Mask', 'id': 1}, {'name': 'NoMask', 'id': 2}]
    if annotation_path is None:
        annotation_path = ANNOTATION_PATH
    # Build the path portably. A literal '\label_map.pbtxt' is an invalid
    # escape sequence, and its backslash is only a separator on Windows.
    label_map_file = os.path.join(annotation_path, 'label_map.pbtxt')
    with open(label_map_file, 'w', encoding='utf-8') as f:
        for label in labels:
            f.write('item { \n')
            f.write('\tname:\'{}\'\n'.format(label['name']))
            f.write('\tid:{}\n'.format(label['id']))
            f.write('}\n')
@zakirangwala
zakirangwala / detect.py
Last active December 1, 2020 05:44
Setup environment paths - Mask Detector
# Setup Paths
# All locations are relative, '/'-separated strings. The workspace-level
# directories are derived from WORKSPACE_PATH, and the model-specific files
# live under MODEL_PATH/my_ssd_mobnet.
WORKSPACE_PATH = 'Tensorflow/workspace'
SCRIPTS_PATH = 'Tensorflow/scripts'
APIMODEL_PATH = 'Tensorflow/tensorflow-models/models'
ANNOTATION_PATH = f'{WORKSPACE_PATH}/annotations'
IMAGE_PATH = f'{WORKSPACE_PATH}/images'
MODEL_PATH = f'{WORKSPACE_PATH}/models'
PRETRAINED_MODEL_PATH = f'{WORKSPACE_PATH}/pre-trained-models'
CONFIG_PATH = f'{MODEL_PATH}/my_ssd_mobnet/pipeline.config'
CHECKPOINT_PATH = f'{MODEL_PATH}/my_ssd_mobnet/'
@zakirangwala
zakirangwala / detect.py
Created December 1, 2020 05:38
Mask_Detector - Import Libraries
# Import Libraries
import os
import cv2
import glob
import pandas as pd
import xml.etree.ElementTree as ET
import numpy as np
import tensorflow as tf
from object_detection.utils import config_util
from object_detection.protos import pipeline_pb2
@zakirangwala
zakirangwala / main_method.py
Last active September 23, 2020 16:34
Tutorial Code : Main Method
if __name__ == "__main__":
greet()
city, country, latitude, longitude = get_location()
while True:
query = listen().lower()
if 'stock' in query:
pass
elif 'weather' in query or 'temperature' in query:
pass
elif 'movie' in query or 'documentary' in query: