Created
June 30, 2023 02:11
-
-
Save ndgnuh/24557116eb3bd50b76862d76760ccee1 to your computer and use it in GitHub Desktop.
[LMDS] hiertext for detection task
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import pickle | |
import json | |
import sys | |
from os import path, makedirs | |
from io import BytesIO | |
from argparse import ArgumentParser | |
from PIL import Image | |
from lmds import GeneratorModule, create_lmds | |
def parse_args():
    """Parse the command-line arguments of the conversion script.

    Positional arguments:
        hiertext_root: input dataset root directory.
        output_root: directory that receives the converted output.

    Options:
        --subsets: comma-separated subset names (default "train,validation").
        --force: boolean switch, off by default.
        --no-gt: boolean switch, off by default (exposed as ``no_gt``).

    Returns:
        The ``argparse.Namespace`` built from ``sys.argv``.
    """
    cli = ArgumentParser()

    # Required positional paths, registered in the order they are expected.
    for positional in ("hiertext_root", "output_root"):
        cli.add_argument(positional)

    cli.add_argument("--subsets", default="train,validation")

    # Plain on/off switches.
    for switch in ("--force", "--no-gt"):
        cli.add_argument(switch, default=False, action="store_true")

    namespace = cli.parse_args()
    return namespace
# Class indices attached to every box, one per annotation level.
_CLASS_WORD = 0
_CLASS_LINE = 1
_CLASS_PARAGRAPH = 2


def _collect_boxes(ann):
    """Flatten one annotation into parallel ``(boxes, classes)`` lists.

    Every paragraph contributes its own ``vertices`` first, followed by each
    of its lines, and each line is followed by its words, so the output keeps
    the nesting order of the annotation.

    Args:
        ann: annotation dict with a ``"paragraphs"`` list; each paragraph has
            ``"vertices"`` and ``"lines"``, each line has ``"vertices"`` and
            ``"words"``, and each word has ``"vertices"``.

    Returns:
        A ``(boxes, classes)`` tuple of equal-length lists.
    """
    boxes = []
    classes = []
    for par in ann["paragraphs"]:
        boxes.append(par["vertices"])
        classes.append(_CLASS_PARAGRAPH)
        for line in par["lines"]:
            boxes.append(line["vertices"])
            classes.append(_CLASS_LINE)
            for word in line["words"]:
                boxes.append(word["vertices"])
                classes.append(_CLASS_WORD)
    return boxes, classes


def _iter_samples(anns, image_dir, no_gt):
    """Yield one pickled ``(image_bytes, boxes, classes)`` sample per annotation.

    Args:
        anns: iterable of annotation dicts carrying an ``"image_id"`` key.
        image_dir: directory holding ``<image_id>.jpg`` files.
        no_gt: when true, annotations are not parsed and ``boxes`` and
            ``classes`` are both ``None``.

    Yields:
        ``bytes`` produced by ``pickle.dumps`` for each sample.
    """
    for ann in anns:
        image_path = path.join(image_dir, f"{ann['image_id']}.jpg")

        # Context managers release the file handle and the buffer even when
        # decoding or encoding fails. The image is re-encoded as JPEG.
        with Image.open(image_path) as image, BytesIO() as buffer:
            image.save(buffer, "JPEG")
            image_bin = buffer.getvalue()

        if no_gt:
            # Skip the annotation parsing if no_gt is specified.
            yield pickle.dumps((image_bin, None, None))
            continue

        boxes, classes = _collect_boxes(ann)
        yield pickle.dumps((image_bin, boxes, classes))


def main():
    """Convert the requested subsets into LMDS outputs.

    For every subset, reads ``<root>/gt/<subset>.jsonl``, takes its
    ``"annotations"`` list, loads the matching ``<root>/<subset>/<id>.jpg``
    images, and hands a generator of pickled samples to ``create_lmds``.

    Raises:
        FileExistsError: if the output root already exists and ``--force``
            was not given.
    """
    # inputs
    args = parse_args()

    # logging: send INFO and above from the root logger to stdout
    logger = logging.getLogger()
    formatter = logging.Formatter("[%(levelname)s %(asctime)s] %(message)s")
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)

    # Directory
    root = args.hiertext_root
    out_root = args.output_root
    # Tolerate spaces and stray commas, e.g. "train, validation".
    subsets = [name.strip() for name in args.subsets.split(",") if name.strip()]
    image_dir_path = ""
    ann_dir_path = "gt"
    # Without --force an existing output root is an error.
    makedirs(out_root, exist_ok=args.force)
    logger.info(f"Data root: {root}")
    logger.info(f"Output root: {out_root}")

    for subset in subsets:
        # Map folder paths
        image_subset_dir_path = path.join(root, image_dir_path, subset)
        ann_path = path.join(root, ann_dir_path, f"{subset}.jsonl")
        output_path = path.join(out_root, subset)
        logger.info(f"Image path: {image_subset_dir_path}")
        logger.info(f"Annotation path: {ann_path}")
        logger.info(f"Output path: {output_path}")

        # Load annotation; the file is parsed as one JSON document and is
        # decoded as UTF-8 regardless of the platform default encoding.
        with open(ann_path, "r", encoding="utf-8") as f:
            anns = json.load(f)["annotations"]

        total = len(anns)
        logger.info(f"Number of {subset} sample: {total}")

        # The generator receives this subset's values as arguments, so it does
        # not depend on loop variables that are rebound on the next iteration.
        gm = GeneratorModule(
            total=total,
            data_gen=_iter_samples(anns, image_subset_dir_path, args.no_gt),
            output_path=output_path,
        )
        create_lmds(gm)
# Run the conversion only when the file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment