Skip to content

Instantly share code, notes, and snippets.

@tamnguyenvan
Last active July 10, 2023 02:37
Show Gist options
  • Select an option

  • Save tamnguyenvan/2b194154b27fa21a939762d65dc802c2 to your computer and use it in GitHub Desktop.

Select an option

Save tamnguyenvan/2b194154b27fa21a939762d65dc802c2 to your computer and use it in GitHub Desktop.
import shutil
import os
from pathlib import Path
def load_files(dir: str, ext: str):
    """Recursively collect files under *dir* whose extension equals *ext*.

    Args:
        dir: Root directory to walk. A ``str`` or ``pathlib.Path`` works; the
            caller in this script passes a ``Path``. A missing directory is
            not an error: ``os.walk`` yields nothing and ``[]`` is returned.
        ext: Extension to match, including the leading dot (e.g. ``'.json'``).
            The comparison is exact and case-sensitive.

    Returns:
        Sorted list of matching paths (each is ``root`` joined with the file
        name). ``os.walk`` order depends on the filesystem. Sorting makes the
        result identical across machines, so IDs assigned downstream and any
        seeded shuffle of the results are reproducible.
    """
    paths = []
    for root, _dirs, files in os.walk(dir):
        # os.walk never reports '.' or '..', so no need to filter them out.
        paths.extend(
            os.path.join(root, name)
            for name in files
            if os.path.splitext(name)[1] == ext
        )
    return sorted(paths)
# Input layout, relative to the current working directory:
#   Annotations/  annotation *.json files, searched recursively
#   Images/       source images, joined with each record's file_name when copied
data_dir = Path('.')
ann_dir, image_dir = data_dir / 'Annotations', data_dir / 'Images'
ann_files = load_files(ann_dir, '.json')
import random
import json
from typing import List, Dict
def merge_ann_files(ann_files: List[str], min_size: int = 15):
    """Merge annotation JSON files into one ID space, keeping only numeric text boxes.

    IDs:
        Every image in every file gets a new, globally unique ``id``. Numbering
        starts at 1 and follows file order, then in-file order.

    Which annotations are kept:
        An annotation is kept only when ``attributes['Number']`` is a non-empty
        string of ASCII digits and both bbox width and height are at least
        ``min_size``.

    Changes made to kept annotations:
        Each gets a new unique ``id`` (starting at 1) and has its ``image_id``
        remapped. It also gains ``text`` (the digit string) and
        ``category_id = 1``.

    Images with no kept annotation are dropped. The dicts loaded from JSON are
    modified in place and returned; they are not copied.

    Args:
        ann_files: Paths of annotation JSON files. Each must have top-level
            ``'images'`` and ``'annotations'`` lists.
        min_size: Minimum bbox width and height for an annotation to be kept.
            Defaults to 15, the previously hard-coded threshold.

    Returns:
        A tuple ``(images, image2anns, categories)``:
        - ``images``: the kept image dicts.
        - ``image2anns``: a dict mapping new image id to the list of its kept
          annotation dicts.
        - ``categories``: the single-entry list
          ``[{'supercategory': 'none', 'name': 'text', 'id': 1}]``.

    Raises:
        KeyError: if an annotation's ``image_id`` is not among the images of
            the same file.
    """
    images = []
    image2anns = {}
    image_id = 1
    ann_id = 1
    for ann_file in ann_files:
        # Context manager closes the handle; a bare open() leaked one per file.
        with open(ann_file, encoding='utf-8') as f:
            ann_data = json.load(f)

        # Renumber this file's images and remember old id -> new id so this
        # file's annotations can be remapped below.
        imageid_map = {}
        file_images = []
        for image in ann_data['images']:
            imageid_map[image['id']] = image_id
            image['id'] = image_id
            image_id += 1
            file_images.append(image)

        for ann in ann_data['annotations']:
            if 'attributes' not in ann or 'Number' not in ann['attributes']:
                continue
            text = ann['attributes']['Number']
            # Require a non-empty, digits-only label. all() over an empty
            # string is True, so without the `not text` guard '' would pass.
            if not text or not all(ch in '0123456789' for ch in text):
                continue
            _x, _y, w, h = ann['bbox']
            if w < min_size or h < min_size:
                continue
            new_image_id = imageid_map[ann['image_id']]
            ann['id'] = ann_id
            ann['image_id'] = new_image_id
            ann_id += 1
            ann['text'] = text
            ann['category_id'] = 1
            image2anns.setdefault(new_image_id, []).append(ann)

        # Keep only this file's images that received at least one annotation.
        # Image ids are globally unique, so checking the shared dict is safe.
        images.extend(image for image in file_images if image['id'] in image2anns)

    categories = [{"supercategory": "none", "name": 'text', "id": 1}]
    print(len(images), len(image2anns))
    return images, image2anns, categories
def save_ann_file(images: List, image2anns: Dict, categories: List, out_file: str, split: str):
    """Write a COCO-style annotation JSON file for one split.

    Each image is written with ``file_name`` set to
    ``jersey-text/<split>/COCOlike_<split>_<id:08d>.jpg``. The caller's image
    dicts are NOT modified; shallow copies are written instead.

    Args:
        images: Image dicts, each with an integer ``'id'``.
        image2anns: Maps image id to a list of annotation dicts. Each
            annotation must have ``'id'``, ``'image_id'``, ``'bbox'``,
            ``'segmentation'``, ``'text'`` and ``'area'`` keys.
        categories: Category list, written unchanged under ``'categories'``.
        out_file: Destination JSON path; an existing file is overwritten.
        split: Split name used in the rewritten file names (e.g. ``'train'``).

    Raises:
        KeyError: if an image id is missing from ``image2anns``, or an
            annotation lacks one of the keys listed above.
    """
    print(f'Creating annotation file: {out_file}')
    # Copy instead of assigning in place. The old version rewrote the
    # caller's file_name, so anything that later needed the source name
    # (such as copying the image) had to run before this function.
    out_images = [
        {**image, 'file_name': f"jersey-text/{split}/COCOlike_{split}_{image['id']:08d}.jpg"}
        for image in images
    ]

    annotations = []
    for image in out_images:
        for ann in image2anns[image['id']]:
            annotations.append({
                'id': ann['id'],
                'image_id': ann['image_id'],
                'bbox': ann['bbox'],
                'segmentation': ann['segmentation'],
                'attributes': {'transcription': ann['text'], 'legible': 1, 'language': 'english'},
                'area': ann['area'],
                'category_id': 1,
            })

    ann_data = {
        'images': out_images,
        'annotations': annotations,
        # Use the categories the caller passed in. The old version ignored
        # this parameter and wrote a hard-coded duplicate.
        'categories': categories,
    }
    with open(out_file, 'wt', encoding='utf-8') as f:
        json.dump(ann_data, f)
def save_images(images: List, image_dir: str, outdir: str, split: str):
    """Copy each image to ``jersey-text/<split>/COCOlike_<split>_<id:08d>.jpg``.

    Each source file is ``image_dir / image['file_name']``. The destination
    directory is created if needed and is relative to the current working
    directory.

    NOTE(review): ``outdir`` is accepted but currently unused. Output always
    goes to ``jersey-text/<split>`` in the working directory; joining it under
    ``outdir`` was apparently considered and disabled. The destination name
    always ends in ``.jpg`` whatever the source format is. Confirm both are
    intended.
    """
    dest_dir = f'jersey-text/{split}'
    os.makedirs(dest_dir, exist_ok=True)
    for record in images:
        source = os.path.join(image_dir, record['file_name'])
        target = os.path.join(dest_dir, f"COCOlike_{split}_{record['id']:08d}.jpg")
        shutil.copy(source, target)
# Merge all annotation files, then make a seeded 80/20 train/val split.
images, image2anns, categories = merge_ann_files(ann_files)

test_size = 0.2
random.seed(12)  # fixed seed: same input order -> same shuffle
num_train = int(len(images) * (1 - test_size))
random.shuffle(images)
train_images = images[:num_train]
test_images = images[num_train:]
print(f'Train/Test split: {len(train_images)}/{len(test_images)}')

image_outdir = 'images'
# Each entry: (split name, images in that split, annotation output path).
splits = (
    ('train', train_images, 'train_annotations.json'),
    ('val', test_images, 'val_annotations.json'),
)
# Copy every split's images before writing any annotation file.
# save_images locates each source image via the record's original
# file_name, so keep this order.
for split, split_images, _ in splits:
    save_images(split_images, image_dir, image_outdir, split)
for split, split_images, ann_path in splits:
    save_ann_file(split_images, image2anns, categories, ann_path, split)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment