Created March 29, 2017 11:52
-
-
Save jinyu121/8f3cd4dd1adaed4108cbe754aecfb794 to your computer and use it in GitHub Desktop.
实验室数据集 (lab dataset)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import os | |
import re | |
import random | |
import lxml.etree as ElementTree | |
import dicttoxml | |
from xml.dom.minidom import parseString | |
from collections import OrderedDict | |
def dict_to_elem(dictionary):
    """Build an ``<Item>`` element with one child element per key of *dictionary*.

    Each key is stripped of surrounding whitespace and has its inner spaces
    replaced by underscores to form the child's tag; the value becomes the
    child's text. Children are appended in the mapping's iteration order.

    Args:
        dictionary: mapping of field name (str) to value. ``str``/``bytes``
            values are stored as-is, ``None`` leaves the child without text,
            and any other value is converted with ``str()``.

    Returns:
        The newly created ``Item`` element.
    """
    item = ElementTree.Element('Item')
    for key, value in dictionary.items():
        field = ElementTree.SubElement(item, key.strip().replace(' ', '_'))
        # lxml only accepts str/bytes/None for `.text`; assigning an int
        # (like the numeric fields this script builds) raises TypeError.
        if value is None or isinstance(value, (str, bytes)):
            field.text = value
        else:
            field.text = str(value)
    return item
def _clamp_box(roi, min_x, min_y, max_x, max_y):
    """Turn one regex-captured ROI into an integer box clamped to the limits.

    Args:
        roi: 4-tuple of digit strings ``(x1, y1, x2, y2)`` as captured by the
            ``[x1,y1,x2,y2]`` pattern.
        min_x, min_y, max_x, max_y: inclusive coordinate limits.

    Returns:
        ``(xmin, ymin, xmax, ymax)`` ints, with the first corner raised to at
        least ``(min_x, min_y)`` and the second lowered to at most
        ``(max_x, max_y)``.

    Raises:
        ValueError: if the clamped box is empty, inverted, or still outside
            the limits (e.g. a box lying entirely beyond the right edge).
    """
    x1 = max(int(roi[0]), min_x)
    y1 = max(int(roi[1]), min_y)
    x2 = min(int(roi[2]), max_x)
    y2 = min(int(roi[3]), max_y)
    # Raise instead of `assert`: asserts are stripped under `python -O`, and
    # this is a check on input data, not an internal invariant.
    if not (min_x <= x1 < x2 <= max_x and min_y <= y1 < y2 <= max_y):
        raise ValueError("invalid box {} -> ({}, {}, {}, {})".format(
            roi, x1, y1, x2, y2))
    return x1, y1, x2, y2


def _build_annotation(frame_name, boxes, width, height):
    """Build the annotation XML document for one frame.

    Args:
        frame_name: image basename without extension, e.g. ``"rgb_12"``.
        boxes: iterable of ``(xmin, ymin, xmax, ymax)`` int tuples; each
            becomes an ``<object>`` named ``person``.
        width, height: image size written into the ``<size>`` element.

    Returns:
        An ``xml.dom.minidom.Document`` rooted at ``<annotation>``.
    """
    # OrderedDicts pin the element order regardless of Python version.
    header = OrderedDict([
        ('folder', "VOC2007"),
        ('filename', frame_name + ".png"),
        ('size', OrderedDict([
            ('width', width),
            ('height', height),
            ('depth', 3),
        ])),
        ('segmented', 0),
    ])
    dom = parseString(
        dicttoxml.dicttoxml(header, attr_type=False, custom_root='annotation'))
    root = dom.childNodes[0]
    for xmin, ymin, xmax, ymax in boxes:
        obj = OrderedDict([
            ('name', 'person'),
            ('pose', 'Left'),
            ('truncated', 1),
            ('difficult', 0),
            ('bndbox', OrderedDict([
                ('xmin', xmin),
                ('ymin', ymin),
                ('xmax', xmax),
                ('ymax', ymax),
            ])),
        ])
        obj_dom = parseString(
            dicttoxml.dicttoxml(obj, attr_type=False, custom_root='object'))
        # Nodes must be imported into `dom` before they can be appended to it.
        root.appendChild(dom.importNode(obj_dom.childNodes[0], True))
    return dom


def _write_image_set(path, names):
    """Write *names* to *path*, one per line, overwriting the file."""
    with open(path, 'w') as out:
        for name in names:
            print(name, file=out)


# Script: read Groundtruth/groundtruth.txt (lines like "<id>: [x1,y1,x2,y2]..."),
# write one Annotations/rgb_<id>.xml per frame, then shuffle the frames into
# ImageSets/Main/{train,val,test,trainval}.txt.
if __name__ == "__main__":
    BASE_DIR = "/home/haoyu/VOCdevkit/VOC2007/"
    MIN_X = 1
    MIN_Y = 1
    MAX_X = 640
    MAX_Y = 480
    PREFIX = "rgb"
    # None keeps the split non-deterministic; set an int for a reproducible one.
    SEED = None
    # Ordered list so the last set can absorb int() truncation leftovers;
    # the ratios are expected to sum to 1.
    SETS = [
        ('train', 0.8),
        ('val', 0.1),
        ('test', 0.1),
    ]

    ground_truth_file = os.path.join(BASE_DIR, "Groundtruth", "groundtruth.txt")
    annotation_dir = os.path.join(BASE_DIR, "Annotations")
    image_sets_dir = os.path.join(BASE_DIR, "ImageSets", "Main")
    pattern_ith = re.compile(r"^(\d+):")
    pattern_roi = re.compile(r"\[(\d+),(\d+),(\d+),(\d+)\]")

    for directory in (annotation_dir, image_sets_dir):
        os.makedirs(directory, exist_ok=True)

    frame_names = []
    with open(ground_truth_file, 'r') as f:
        for line_no, line in enumerate(f, 1):
            print(line.rstrip("\n"))
            if not line.strip():
                # e.g. a trailing newline at end of file: nothing to annotate.
                continue
            ith = pattern_ith.match(line)
            if ith is None:
                raise ValueError("{}:{}: expected '<id>:' prefix, got {!r}".format(
                    ground_truth_file, line_no, line))
            frame_name = "{}_{}".format(PREFIX, ith.group(1))
            frame_names.append(frame_name)
            boxes = [_clamp_box(roi, MIN_X, MIN_Y, MAX_X, MAX_Y)
                     for roi in pattern_roi.findall(line)]
            dom = _build_annotation(frame_name, boxes, MAX_X, MAX_Y)
            with open(os.path.join(annotation_dir, frame_name + ".xml"), "w") as anno:
                print(dom.toprettyxml(), file=anno)

    # Split the dataset.
    random.Random(SEED).shuffle(frame_names)
    total = len(frame_names)
    split_members = OrderedDict()
    start = 0
    for position, (set_name, set_scale) in enumerate(SETS):
        if position == len(SETS) - 1:
            stop = total  # last set takes every remaining frame
        else:
            stop = min(total, start + int(total * set_scale))
        split_members[set_name] = frame_names[start:stop]
        start = stop

    for set_name, members in split_members.items():
        _write_image_set(os.path.join(image_sets_dir, set_name + ".txt"), members)
    _write_image_set(os.path.join(image_sets_dir, "trainval.txt"),
                     split_members['train'] + split_members['val'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.