Skip to content

Instantly share code, notes, and snippets.

@Pangoraw
Last active October 17, 2022 11:40
Show Gist options
  • Save Pangoraw/136318c360578f4bb307e399f3cbc22d to your computer and use it in GitHub Desktop.
Save Pangoraw/136318c360578f4bb307e399f3cbc22d to your computer and use it in GitHub Desktop.
COWC to COCO converter
import os
from pathlib import Path
import json
import argparse
from PIL import Image
import numpy as np
# import detectron2
# from detectron2.structures import BoxMode
import torchvision.transforms as T
from torch.hub import download_url_to_file
from tqdm import tqdm
# Pillow warns about images larger than MAX_IMAGE_PIXELS and raises
# DecompressionBombError above twice that (default limit ~89.5M pixels).
# Raise the limit so the large full-scene label images can be opened.
Image.MAX_IMAGE_PIXELS = 933_120_000
def load_label_file(path):
    """Load a label PNG as a 2-D array that is non-zero where the image is labelled.

    Multi-channel images are summed over their last (channel) axis, so a pixel is
    non-zero when any channel is non-zero. A 2-D image (single-channel or
    palette mode) is returned unchanged: summing its last axis would collapse
    the image columns instead of the channels and yield a 1-D array.

    Args:
        path: path of the label image.

    Returns:
        numpy array of shape (height, width).
    """
    # ``with`` closes the file; np.asarray has already copied the pixel data.
    with Image.open(path) as img:
        arr = np.asarray(img)
    if arr.ndim == 2:
        return arr
    return arr.sum(-1)
class Dataset:
    """Convert one split of a COWC ground-truth set into COCO detection format.

    Each patch listed in ``<base_folder>/<split>_list.txt`` is read from
    ``<base_folder>/<split>/``, center-cropped to ``crop_size`` pixels and saved
    under ``<base_folder>/<split>_cropped/``. Car boxes come from the non-zero
    pixels of the matching ``*_Annotated_Cars.png`` label image. That image is
    stored in ``<base_folder>/labels/`` and downloaded from the COWC server when
    missing. :meth:`process_all` writes the result to
    ``<base_folder>/<split>.json``.
    """

    def __init__(self, base_folder, split):
        """Prepare the folders and read the list of patches for ``split``.

        Args:
            base_folder: ``pathlib.Path`` of the ground-truth set. Its final
                component is also used in the label download URL.
            split: ``"train"`` or ``"test"``.

        Raises:
            ValueError: if ``split`` is not ``"train"`` or ``"test"``.
            FileNotFoundError: if ``<base_folder>/<split>`` does not exist.
        """
        # Real exceptions instead of ``assert`` so validation survives ``python -O``.
        if split not in ("train", "test"):
            raise ValueError(f"split must be 'train' or 'test', got {split!r}")
        self.base_folder = base_folder
        self.input_image_folder = self.base_folder / split
        # Check the input before creating any output folder on disk.
        if not self.input_image_folder.exists():
            raise FileNotFoundError(
                f"input image folder not found: {self.input_image_folder}"
            )
        self.output_image_folder = self.base_folder / f"{split}_cropped"
        self.output_image_folder.mkdir(exist_ok=True)

        self.input_files_file = self.base_folder / f"{split}_list.txt"
        with open(self.input_files_file) as f:
            # Skip blank lines: an empty name would resolve to the input folder.
            self.input_files = [line.strip() for line in f if line.strip()]

        (self.base_folder / "labels").mkdir(exist_ok=True)
        # Decoded label images keyed by file name. Filled lazily and holding at
        # most one entry, to bound memory use (see _get_label_image).
        self.label_files = {}

        self.crop_size = 200
        self.cropper = T.CenterCrop((self.crop_size, self.crop_size))
        self.id_counter = 0
        # Annotation ids start at 1: pycocotools' COCOeval records matches by
        # ground-truth id in zero-initialised arrays, so a detection matched to
        # an annotation with id 0 would be counted as a false positive.
        self.annot_id_counter = 1
        self.output_json_file = self.base_folder / f"{split}.json"

    def process_all(self):
        """Process every listed patch and write the COCO JSON to ``output_json_file``."""
        records = [
            self.process(self.input_image_folder / name)
            for name in tqdm(self.input_files)
        ]
        output = dict(
            info=dict(
                description="cowc dataset",
                url="",
                version="1.0",
                year="2022",
                contributor="Paul Berg",
                date_created="13/11/2022",
            ),
            images=[
                dict(
                    id=rec["image_id"],
                    file_name=rec["file_name"],
                    height=rec["height"],
                    width=rec["width"],
                )
                for rec in records
            ],
            categories=[dict(id=1, name="car")],
            annotations=[ann for rec in records for ann in rec["annotations"]],
        )
        with open(self.output_json_file, "w") as f:
            json.dump(output, f)

    def process(self, path: Path) -> dict:
        """Crop and save one patch, then build its COCO image record.

        Args:
            path: patch image. Its file name encodes the source label image and
                the patch center (parsed by ``_parse_patch_name``).

        Returns:
            dict with ``image_id``, ``file_name``, ``height``, ``width`` and
            ``annotations``. There is one ``car_size`` square box per non-zero
            label pixel inside the crop window. This assumes each car is marked
            by a single pixel (TODO confirm).
        """
        # ``with`` closes the source file once the crop has been saved.
        with Image.open(path) as img:  # presumably a 3x256x256 RGB patch
            self.cropper(img).save(self.output_image_folder / path.name)

        label_file, x, y = self._parse_patch_name(path.name)
        rows, cols = self._car_centers(self._get_label_image(label_file), x, y)

        car_size = 300 // 15  # 3m, 15cm/pixels
        half = car_size // 2
        x0, x1 = cols - half, cols + half
        y0, y1 = rows - half, rows + half
        w, h = x1 - x0, y1 - y0
        # NOTE: boxes are not clipped to the crop. Near the border they may start
        # at negative coordinates or extend past ``crop_size``.

        image_id = self.id_counter
        self.id_counter += 1
        first_annot_id = self.annot_id_counter
        self.annot_id_counter += len(x0)

        return dict(
            image_id=image_id,
            file_name=path.name,
            height=self.crop_size,
            width=self.crop_size,
            annotations=[
                dict(
                    id=first_annot_id + i,
                    image_id=image_id,
                    bbox=list(map(int, bbox)),
                    iscrowd=0,
                    area=car_size ** 2,
                    category_id=1,
                )
                for i, bbox in enumerate(zip(x0, y0, w, h))
            ],
        )

    @staticmethod
    def _parse_patch_name(img_name):
        """Return ``(label_file, x, y)``: the label PNG name and the patch center.

        ``img_name`` is split on ``.`` into 6 or 7 fields. The label-image stem,
        ``x``, ``y`` and one further numeric field are taken by position.

        Raises:
            NotImplementedError: for any other number of fields.
        """
        fields = img_name.split(".")
        if len(fields) == 7:
            _cls, _, label_file, x, y, r, _ = fields
        elif len(fields) == 6:
            _cls, label_file, x, y, r, _ = fields
        else:
            raise NotImplementedError(img_name)
        # Coordinates go through float() so decimal values are accepted. r is
        # parsed only to validate the name; it is otherwise unused.
        x, y, _ = float(x), float(y), float(r)

        if label_file.startswith("Potsdam"):
            label_file = label_file.replace("Potsdam", "top_potsdam")
        else:
            _, label_file = label_file.split("_")
        return label_file + "_Annotated_Cars.png", int(x), int(y)

    def _get_label_image(self, label_file):
        """Return the decoded label image ``label_file``, loading it on a cache miss.

        On a miss the cache is emptied first, so only one label image is held in
        memory. The PNG is downloaded only if it is not already on disk.
        """
        if label_file not in self.label_files:
            self.label_files.clear()
            label_file_path = self.base_folder / "labels" / label_file
            # Reuse an already-downloaded file instead of fetching it again after
            # it has been evicted from the cache.
            if not label_file_path.exists():
                base_folder_name = self.base_folder.name
                download_url_to_file(
                    f"https://gdo152.llnl.gov/cowc/download/cowc/datasets/ground_truth_sets/{base_folder_name}/"
                    + label_file,
                    label_file_path,
                )
            self.label_files[label_file] = load_label_file(label_file_path)
        return self.label_files[label_file]

    def _car_centers(self, label_img, x, y):
        """Return ``(rows, cols)`` of the non-zero label pixels inside the crop window.

        The window is the ``crop_size`` square centered on ``(x, y)`` in label
        image coordinates. The returned indices are relative to its top-left
        corner, i.e. they are pixel coordinates in the cropped patch.
        """
        offset = self.crop_size // 2
        top, left = y - offset, x - offset
        # Clamp the start at 0: a negative start index would wrap around in
        # numpy slicing and select the wrong (typically empty) region.
        row0, col0 = max(top, 0), max(left, 0)
        rows, cols = np.nonzero(label_img[row0 : y + offset, col0 : x + offset])
        # Shift back into crop coordinates (a no-op when nothing was clamped).
        return rows + (row0 - top), cols + (col0 - left)
def main():
    """Command-line entry point: convert one COWC split to COCO format.

    Usage: ``python <script> BASE_FOLDER [{train,test}]``; the split defaults
    to ``train``.
    """
    parser = argparse.ArgumentParser("COCW to COCO")
    parser.add_argument("base_folder", type=Path)
    # A positional argument only uses its default with nargs="?". Without it,
    # argparse treats ``split`` as mandatory and ``default`` is never applied.
    parser.add_argument(
        "split", type=str, nargs="?", default="train", choices=["train", "test"]
    )
    args = parser.parse_args()
    data = Dataset(base_folder=args.base_folder, split=args.split)
    data.process_all()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment