Last active
July 28, 2021 04:57
-
-
Save Taehun/20d957470a3d1e205c08f924542fbc65 to your computer and use it in GitHub Desktop.
Add AIMMO dataset to DDRNet.pytorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From 53ed3d93263ae734ea2b09607b2769454924ad1e Mon Sep 17 00:00:00 2001
From: Taehun Kim <[email protected]>
Date: Wed, 28 Jul 2021 13:53:33 +0900
Subject: [PATCH] Add AIMMO dataset
---
 lib/datasets/__init__.py | 11 +++---
 lib/datasets/aimmo.py    | 78 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 6 deletions(-)
 create mode 100644 lib/datasets/aimmo.py
diff --git a/lib/datasets/__init__.py b/lib/datasets/__init__.py | |
index 836d3b3..6df22c7 100644 | |
--- a/lib/datasets/__init__.py | |
+++ b/lib/datasets/__init__.py | |
@@ -4,13 +4,12 @@ | |
# Written by Ke Sun ([email protected]) | |
# ------------------------------------------------------------------------------ | |
-from __future__ import absolute_import | |
-from __future__ import division | |
-from __future__ import print_function | |
+from __future__ import absolute_import, division, print_function | |
+from .ade20k import ADE20K as ade20k | |
+from .aimmo import AIMMO as aimmo | |
from .cityscapes import Cityscapes as cityscapes | |
+from .cocostuff import COCOStuff as cocostuff | |
from .lip import LIP as lip | |
-from .pascal_ctx import PASCALContext as pascal_ctx | |
-from .ade20k import ADE20K as ade20k | |
from .map import MAP as map | |
-from .cocostuff import COCOStuff as cocostuff | |
\ No newline at end of file | |
+from .pascal_ctx import PASCALContext as pascal_ctx | |
diff --git a/lib/datasets/aimmo.py b/lib/datasets/aimmo.py | |
new file mode 100644 | |
index 0000000..8b61e6e | |
--- /dev/null | |
+++ b/lib/datasets/aimmo.py | |
@@ -0,0 +1,78 @@ | |
+# ------------------------------------------------------------------------------ | |
+# Copyright (c) AIMMO | |
+# Licensed under the MIT License. | |
+# Written by Taehun Kim ([email protected]) | |
+# ------------------------------------------------------------------------------ | |
+ | |
+import os | |
+ | |
+import cv2 | |
+import numpy as np | |
+import torch | |
+from PIL import Image | |
+from torch.nn import functional as F | |
+ | |
+from .base_dataset import BaseDataset | |
+ | |
+ | |
class AIMMO(BaseDataset):
    """Semantic-segmentation dataset for AIMMO images listed in a text file.

    Each non-blank line of the list file ``root + list_path`` holds an image
    path and a label path separated by whitespace. Both paths are resolved
    relative to ``<root>/aimmo``. The normalisation/augmentation arguments
    are forwarded to ``BaseDataset`` unchanged.
    """

    def __init__(
        self,
        root,
        list_path,
        num_samples=None,
        num_classes=10,
        multi_scale=True,
        flip=True,
        ignore_label=-1,
        base_size=1024,
        crop_size=(1024, 1024),
        downsample_rate=1,
        scale_factor=16,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ):
        """Read the sample list and prepare per-sample metadata.

        Args:
            root: Dataset root directory. The list file path is built by plain
                string concatenation (``root + list_path``), so ``root`` must
                end with a path separator.
            list_path: Path of the list file, relative to ``root``.
            num_samples: If truthy, keep only the first ``num_samples`` entries.
            num_classes: Number of segmentation classes.
            multi_scale: Forwarded to ``gen_sample`` for scale augmentation.
            flip: Forwarded to ``gen_sample`` for flip augmentation.
            ignore_label, base_size, crop_size, downsample_rate, scale_factor,
            mean, std: Forwarded positionally to ``BaseDataset.__init__``.

        Raises:
            ValueError: If a list entry does not have exactly two fields.
        """
        super(AIMMO, self).__init__(ignore_label, base_size, crop_size, downsample_rate, scale_factor, mean, std)

        self.root = root
        self.num_classes = num_classes
        self.list_path = list_path
        self.class_weights = None

        self.multi_scale = multi_scale
        self.flip = flip

        # Close the list file deterministically, and skip blank lines (for
        # example a trailing newline), which would otherwise yield empty
        # entries that cannot be unpacked into (image, label).
        with open(root + list_path) as list_file:
            self.img_list = [line.strip().split() for line in list_file if line.strip()]

        self.files = self.read_files()
        if num_samples:
            self.files = self.files[:num_samples]

    def read_files(self):
        """Build one sample record per entry of ``self.img_list``.

        Returns:
            list[dict]: Records with keys ``"img"`` (image path), ``"label"``
            (label path) and ``"name"`` (label file basename without extension).

        Raises:
            ValueError: If an entry does not have exactly two whitespace-separated
                fields (image path and label path).
        """
        files = []
        for item in self.img_list:
            if len(item) != 2:
                raise ValueError(
                    "Expected '<image_path> <label_path>' in {}, got: {!r}".format(self.list_path, " ".join(item))
                )
            image_path, label_path = item
            name = os.path.splitext(os.path.basename(label_path))[0]
            files.append({"img": image_path, "label": label_path, "name": name})
        return files

    def resize_image(self, image, label, size):
        """Resize an image/label pair to ``size``, given in cv2 ``(width, height)`` order.

        The image is resized with bilinear interpolation. The label uses
        nearest-neighbour interpolation, so output pixels keep values that
        already exist in the input label and class ids are never blended.
        """
        image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
        label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
        return image, label

    def __getitem__(self, index):
        """Load, resize and augment the sample at ``index``.

        Returns:
            tuple: ``(image, label, size, name)``. ``image`` and ``label`` come
            from ``gen_sample``. ``size`` is the label's original
            ``(height, width)`` before resizing, as a numpy array. ``name`` is
            the sample name from ``read_files``.

        Raises:
            FileNotFoundError: If the image or label file cannot be read.
        """
        item = self.files[index]
        name = item["name"]
        image_path = os.path.join(self.root, "aimmo", item["img"])
        label_path = os.path.join(self.root, "aimmo", item["label"])

        # cv2.imread signals failure by returning None rather than raising, so
        # check explicitly to report the offending path.
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError("Cannot read image: {}".format(image_path))
        # NOTE(review): labels are read as 8-bit grayscale, so no pixel can ever
        # equal the default ignore_label=-1. Confirm how unlabeled pixels are
        # encoded in AIMMO masks.
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError("Cannot read label: {}".format(label_path))

        # Original label size, captured before any resizing.
        size = label.shape

        # cv2.resize expects dsize as (width, height).
        # NOTE(review): assumes crop_size is (height, width), the order
        # BaseDataset's cropping uses. Swapping here keeps non-square crop
        # sizes from being transposed.
        crop_h, crop_w = self.crop_size
        image, label = self.resize_image(image, label, (crop_w, crop_h))
        image, label = self.gen_sample(image, label, self.multi_scale, self.flip)

        return image.copy(), label.copy(), np.array(size), name
--
2.29.2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment