Last active
May 13, 2022 13:48
-
-
Save burnpiro/c3835a1f914545f2034f4190b1e83153 to your computer and use it in GitHub Desktop.
TensorFlow 2 custom dataset Sequence
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from data.data_generator import DataGenerator | |
from config import cfg | |
# Create the train and validation datasets.
train_datagen = DataGenerator(file_path=cfg.TRAIN.DATA_PATH, config_path=cfg.TRAIN.ANNOTATION_PATH)
# The config defines validation paths under cfg.VAL; there is no cfg.TEST.
val_generator = DataGenerator(file_path=cfg.VAL.DATA_PATH, config_path=cfg.VAL.ANNOTATION_PATH, debug=False)

# NOTE: `model` must already be built and compiled at this point; it is not
# defined in this snippet. This call also requires cfg.TRAIN.EPOCHS in the config.
# Model.fit accepts a keras Sequence directly (fit_generator is deprecated), and
# passing the validation Sequence makes the validation metrics available.
model.fit(
    train_datagen,
    validation_data=val_generator,
    epochs=cfg.TRAIN.EPOCHS,
    callbacks=[],  # add your TF callbacks here
    shuffle=True,
    verbose=1,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from easydict import EasyDict | |
__C = EasyDict()
cfg = __C

# Network options
__C.NN = EasyDict()
__C.NN.INPUT_SIZE = 224
# NOTE(review): the data generator builds per-image targets of shape
# (GRID_SIZE, GRID_SIZE, 5), but GRID_SIZE was never defined here. It must equal
# the spatial size of the model's output map. INPUT_SIZE // 32 (= 7) assumes a
# stride-32 backbone such as MobileNetV2; confirm against the model.
__C.NN.GRID_SIZE = __C.NN.INPUT_SIZE // 32

# Train options
__C.TRAIN = EasyDict()
__C.TRAIN.DATA_PATH = "./data/WIDER_train/images/"
__C.TRAIN.ANNOTATION_PATH = "./data/wider_face_split/wider_face_train_bbx_gt.txt"
__C.TRAIN.BATCH_SIZE = 16
# NOTE(review): read by the training script but previously undefined.
# This is a placeholder value; tune it for your run.
__C.TRAIN.EPOCHS = 100

# Validation options
__C.VAL = EasyDict()
__C.VAL.DATA_PATH = "./data/WIDER_val/images/"
__C.VAL.ANNOTATION_PATH = "./data/wider_face_split/wider_face_val_bbx_gt.txt"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import math | |
import numpy as np | |
import tensorflow as tf | |
from config import cfg | |
def get_box(data):
    """Return the integer bounding box stored in a split annotation line.

    Args:
        data: Indexable sequence laid out as
            [x0, y0, w, h, blur, expression, illumination, invalid, occlusion, pose].
            Only the first four entries are read, and each is converted with int().

    Returns:
        Tuple ``(x0, y0, w, h)`` of ints.

    Raises:
        IndexError: if ``data`` has fewer than four entries.
        ValueError: if one of the first four entries is not an integer literal.
    """
    left, top, width, height = (int(data[k]) for k in range(4))
    return left, top, width, height
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(batch_images, batch_boxes)`` batches.

    Boxes come from an annotation file parsed by ``_load_annotations``. Each kept
    box becomes its own sample: an image with k valid boxes appears k times in
    ``self.boxes``, and each of those samples labels only that one box in its
    target grid. The grid starts as zeros.
    NOTE(review): other faces in the same image are therefore unlabelled in that
    sample. Confirm this is intended.

    Args:
        file_path: Directory containing the images. Image names from the
            annotation file are joined onto it.
        config_path: Path to the annotation file (e.g. ``*_bbx_gt.txt``).
        debug: Stored as ``self.debug``; not otherwise used in this class.
        batch_size: Samples per batch. ``None`` (the default) uses
            ``cfg.TRAIN.BATCH_SIZE``, which the original code always used.

    Exits the process (``SystemExit``, status 1) if ``config_path`` is not a
    file or ``file_path`` is not a directory.
    """

    def __init__(self, file_path, config_path, debug=False, batch_size=None):
        super().__init__()
        self.boxes = []  # one (image_name, x0, y0, w, h) tuple per kept box
        self.debug = debug
        self.data_path = file_path
        self._batch_override = batch_size

        # sys.exit(<str>) writes the message to stderr and exits with status 1.
        # The previous print + bare sys.exit() reported success (status 0).
        if not os.path.isfile(config_path):
            sys.exit("File path {} does not exist. Exiting...".format(config_path))
        if not os.path.isdir(file_path):
            sys.exit("Images folder path {} does not exist. Exiting...".format(file_path))

        self._load_annotations(config_path)

    def _load_annotations(self, config_path):
        """Append every non-degenerate box in ``config_path`` to ``self.boxes``.

        Expected layout, repeated until EOF: an image-name line, an object-count
        line, then the box lines. A count of 0 is still followed by one box line,
        so ``max(count, 1)`` lines are read. Every box line goes through the same
        zero-width/zero-height filter. Previously, the line read for a count of 0
        was appended unfiltered, which labelled a box even when its size was 0.
        """
        with open(config_path) as fp:
            image_name = fp.readline()
            while image_name:
                name = image_name.strip()
                num_of_obj = int(fp.readline())
                for _ in range(max(num_of_obj, 1)):
                    # split() (not split(' ')) tolerates repeated or leading whitespace.
                    x0, y0, w, h = get_box(fp.readline().split())
                    if w == 0 or h == 0:
                        # drop boxes with no width or no height
                        continue
                    self.boxes.append((name, x0, y0, w, h))
                image_name = fp.readline()

    def _batch_size(self):
        """Return the constructor's batch_size override, else cfg.TRAIN.BATCH_SIZE."""
        if self._batch_override is None:
            return cfg.TRAIN.BATCH_SIZE
        return self._batch_override

    def __len__(self):
        """Return the number of batches per epoch (the last one may be partial)."""
        return math.ceil(len(self.boxes) / self._batch_size())

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(batch_images, batch_boxes)``.

        batch_images: float32, shape (n, INPUT_SIZE, INPUT_SIZE, 3), passed
            through MobileNetV2's ``preprocess_input``.
        batch_boxes: float32, shape (n, GRID_SIZE, GRID_SIZE, 5). For each sample,
            only the cell selected from the box's (x0, y0) holds
            [h / image_height, w / image_width, y offset in cell, x offset in cell, 1].
        """
        size = self._batch_size()
        boxes = self.boxes[idx * size:(idx + 1) * size]
        grid = cfg.NN.GRID_SIZE
        batch_images = np.zeros((len(boxes), cfg.NN.INPUT_SIZE, cfg.NN.INPUT_SIZE, 3), dtype=np.float32)
        batch_boxes = np.zeros((len(boxes), grid, grid, 5), dtype=np.float32)

        for i, (path, x0, y0, w, h) in enumerate(boxes):
            image_path = os.path.join(self.data_path, path)
            # The first load only reads the original dimensions, because the box
            # coordinates are expressed in original-image pixels.
            original = tf.keras.preprocessing.image.load_img(image_path)
            image_width = original.width
            image_height = original.height

            proc_image = tf.keras.preprocessing.image.load_img(
                image_path, target_size=(cfg.NN.INPUT_SIZE, cfg.NN.INPUT_SIZE))
            proc_image = tf.keras.preprocessing.image.img_to_array(proc_image)
            # Assign the returned value. The original line used `-` instead of
            # `=`, so its result was thrown away.
            batch_images[i] = tf.keras.applications.mobilenet_v2.preprocess_input(proc_image)

            # Clamp the corner inside the image. The upper bound is size - 1 so
            # the grid index below stays <= grid - 1. Clamping to the size itself
            # let x0 == image_width map to cell `grid`, which is an IndexError.
            x0 = min(max(x0, 0), image_width - 1)
            y0 = min(max(y0, 0), image_height - 1)

            # NOTE(review): the responsible cell is picked from the box's top-left
            # corner (x0, y0), not its centre. Confirm this matches how the model's
            # predictions are decoded.
            x_c = (grid / image_width) * x0
            y_c = (grid / image_height) * y0
            floor_x = math.floor(x_c)
            floor_y = math.floor(y_c)
            batch_boxes[i, floor_y, floor_x] = (
                h / image_height,
                w / image_width,
                y_c - floor_y,
                x_c - floor_x,
                1,
            )

        return batch_images, batch_boxes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hello,
I hope you are doing well.
I downloaded the app you developed (https://github.com/burnpiro/tiny-face-detection-tensorflow2), along with the data and everything else it needs.
I trained a model without modifying the code, and the val_iou is very poor, so the model is not able to detect faces.
Do you have local changes that you haven't pushed?