Created
April 3, 2018 10:41
-
-
Save walkerlala/82d978e68407e65158e8825cd470d7e1 to your computer and use it in GitHub Desktop.
A patch for training deeplabv3 on the ADE20K dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/research/deeplab/datasets/build_ade20k_data.py b/research/deeplab/datasets/build_ade20k_data.py | |
new file mode 100644 | |
index 0000000..7dcf79c | |
--- /dev/null | |
+++ b/research/deeplab/datasets/build_ade20k_data.py | |
@@ -0,0 +1,116 @@ | |
+# Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+ | |
+import glob | |
+import math | |
+import os | |
+import random | |
+import string | |
+import sys | |
+from PIL import Image | |
+import build_data | |
+import tensorflow as tf | |
+ | |
+FLAGS = tf.app.flags.FLAGS | |
+flags = tf.app.flags | |
+ | |
+tf.app.flags.DEFINE_string( | |
+ 'train_image_folder', | |
+ './ADEChallengeData2016/images/training', | |
+ 'Folder containing trainng images') | |
+tf.app.flags.DEFINE_string( | |
+ 'train_image_label_folder', | |
+ './ADEChallengeData2016/annotations/training', | |
+ 'Folder containing annotations for trainng images') | |
+ | |
+tf.app.flags.DEFINE_string( | |
+ 'val_image_folder', | |
+ './ADE20K_2016_07_26/images/validation', | |
+ 'Folder containing validation images') | |
+ | |
+tf.app.flags.DEFINE_string( | |
+ 'val_image_label_folder', | |
+ './ADE20K_2016_07_26/annotations/validation', | |
+ 'Folder containing annotations for validation') | |
+ | |
+tf.app.flags.DEFINE_string( | |
+ 'output_dir', './tfrecord', | |
+ 'Path to save converted SSTable of Tensorflow example') | |
+ | |
+_NUM_SHARDS = 4 | |
+ | |
+def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir): | |
+ """ Convert the ADE20k dataset into into tfrecord format (SSTable). | |
+ | |
+ Args: | |
+ dataset_split: dataset split (e.g., train, val) | |
+ dataset_dir: dir in which the dataset locates | |
+ dataset_label_dir: dir in which the annotations locates | |
+ | |
+ Raises: | |
+ RuntimeError: If loaded image and label have different shape. | |
+ """ | |
+ | |
+ img_names = glob.glob(os.path.join(dataset_dir, '*.jpg')) | |
+ random.shuffle(img_names) | |
+ seg_names = [] | |
+ for f in img_names: | |
+ # get the filename without the extension | |
+ basename = os.path.basename(f).split(".")[0] | |
+ # cover its corresponding *_seg.png | |
+ seg = os.path.join(dataset_label_dir, basename+'.png') | |
+ seg_names.append(seg) | |
+ | |
+ num_images = len(img_names) | |
+ num_per_shard = int(math.ceil(num_images) / float(_NUM_SHARDS)) | |
+ | |
+ image_reader = build_data.ImageReader('jpeg', channels=3) | |
+ label_reader = build_data.ImageReader('png', channels=1) | |
+ | |
+ for shard_id in range(_NUM_SHARDS): | |
+ output_filename = os.path.join( | |
+ FLAGS.output_dir, | |
+ '%s-%05d-of-%05d.tfrecord' % (dataset_split, shard_id, _NUM_SHARDS)) | |
+ with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: | |
+ start_idx = shard_id * num_per_shard | |
+ end_idx = min((shard_id + 1) * num_per_shard, num_images) | |
+ for i in range(start_idx, end_idx): | |
+ sys.stdout.write('\r>> Converting image %d/%d shard %d' % ( | |
+ i + 1, num_images, shard_id)) | |
+ sys.stdout.flush() | |
+ # Read the image. | |
+ image_filename = img_names[i] | |
+ image_data = tf.gfile.FastGFile(image_filename, 'r').read() | |
+ height, width = image_reader.read_image_dims(image_data) | |
+ # Read the semantic segmentation annotation. | |
+ seg_filename = seg_names[i] | |
+ seg_data = tf.gfile.FastGFile(seg_filename, 'r').read() | |
+ seg_height, seg_width = label_reader.read_image_dims(seg_data) | |
+ if height != seg_height or width != seg_width: | |
+ raise RuntimeError('Shape mismatched between image and label.') | |
+ # Convert to tf example. | |
+ example = build_data.image_seg_to_tfexample( | |
+ image_data, img_names[i], height, width, seg_data) | |
+ tfrecord_writer.write(example.SerializeToString()) | |
+ sys.stdout.write('\n') | |
+ sys.stdout.flush() | |
+ | |
+def main(unused_argv):  # unused_argv is never read in the body | |
+ tf.gfile.MakeDirs(FLAGS.output_dir)  # create the output directory before any shard is opened | |
+ _convert_dataset('train', FLAGS.train_image_folder, FLAGS.train_image_label_folder)  # writes train-*.tfrecord | |
+ _convert_dataset('val', FLAGS.val_image_folder, FLAGS.val_image_label_folder)  # writes val-*.tfrecord | |
+ | |
+if __name__ == '__main__': | |
+ tf.app.run() | |
diff --git a/research/deeplab/datasets/build_ade20k_eval_data.py b/research/deeplab/datasets/build_ade20k_eval_data.py | |
new file mode 100644 | |
index 0000000..a1acead | |
--- /dev/null | |
+++ b/research/deeplab/datasets/build_ade20k_eval_data.py | |
@@ -0,0 +1,104 @@ | |
+# Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+ | |
+import glob | |
+import math | |
+import os | |
+import random | |
+import string | |
+import sys | |
+from PIL import Image | |
+import build_data | |
+import tensorflow as tf | |
+ | |
+FLAGS = tf.app.flags.FLAGS | |
+flags = tf.app.flags | |
+ | |
+tf.app.flags.DEFINE_string( | |
+ 'eval_image_folder', | |
+ './ADE20K_2016_07_26/images/training', | |
+ 'Folder containing trainng images and segmentation annotations') | |
+ | |
+tf.app.flags.DEFINE_string( | |
+ 'output_dir', './tfrecord', | |
+ 'Path to save converted SSTable of Tensorflow example') | |
+ | |
+_NUM_SHARDS = 1 | |
+ | |
+def _convert_dataset(dataset_split, dataset_dir): | |
+ """ Convert the ADE20k dataset into into tfrecord format (SSTable). | |
+ | |
+ Args: | |
+ dataset_split: dataset split (e.g., train, test) | |
+ dataset_dir: dir in which the dataset locates | |
+ | |
+ Raises: | |
+ RuntimeError: If loaded image and label have different shape. | |
+ """ | |
+ | |
+ img_names = [] | |
+ seg_names = [] | |
+ for i in string.ascii_lowercase: | |
+ dir_name = os.path.join(dataset_dir, i) | |
+ img_names.extend(glob.glob(os.path.join(dir_name, '*/*.jpg'))) | |
+ img_names.extend(glob.glob(os.path.join(dir_name, '*/*.jpeg'))) | |
+ random.shuffle(img_names) | |
+ for f in img_names: | |
+ # get the filename without the extension | |
+ basename = os.path.basename(f).split(".")[0] | |
+ # cover its corresponding *_seg.png | |
+ seg = os.path.join(os.path.dirname(f), basename+'_seg.png') | |
+ seg_names.append(seg) | |
+ | |
+ num_images = len(img_names) | |
+ num_per_shard = int(math.ceil(num_images) / float(_NUM_SHARDS)) | |
+ | |
+ image_reader = build_data.ImageReader('jpeg', channels=3) | |
+ label_reader = build_data.ImageReader('png', channels=1) | |
+ | |
+ for shard_id in range(_NUM_SHARDS): | |
+ output_filename = os.path.join( | |
+ FLAGS.output_dir, | |
+ '%s-%05d-of-%05d.tfrecord' % (dataset_split, shard_id, _NUM_SHARDS)) | |
+ with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: | |
+ start_idx = shard_id * num_per_shard | |
+ end_idx = min((shard_id + 1) * num_per_shard, num_images) | |
+ for i in range(start_idx, end_idx): | |
+ sys.stdout.write('\r>> Converting image %d/%d shard %d' % ( | |
+ i + 1, num_images, shard_id)) | |
+ sys.stdout.flush() | |
+ # Read the image. | |
+ image_filename = img_names[i] | |
+ image_data = tf.gfile.FastGFile(image_filename, 'r').read() | |
+ height, width = image_reader.read_image_dims(image_data) | |
+ # Read the semantic segmentation annotation. | |
+ seg_filename = seg_names[i] | |
+ seg_data = tf.gfile.FastGFile(seg_filename, 'r').read() | |
+ seg_height, seg_width = label_reader.read_image_dims(seg_data) | |
+ if height != seg_height or width != seg_width: | |
+ raise RuntimeError('Shape mismatched between image and label.') | |
+ # Convert to tf example. | |
+ example = build_data.image_seg_to_tfexample( | |
+ image_data, img_names[i], height, width, seg_data) | |
+ tfrecord_writer.write(example.SerializeToString()) | |
+ sys.stdout.write('\n') | |
+ sys.stdout.flush() | |
+ | |
+def main(unused_argv):  # unused_argv is never read in the body | |
+ tf.gfile.MakeDirs(FLAGS.output_dir)  # create the output directory before any shard is opened | |
+ _convert_dataset('eval', FLAGS.eval_image_folder)  # writes eval-*.tfrecord | |
+ | |
+if __name__ == '__main__': | |
+ tf.app.run() | |
diff --git a/research/deeplab/datasets/segmentation_dataset.py b/research/deeplab/datasets/segmentation_dataset.py | |
index a777252..cedeb54 100644 | |
--- a/research/deeplab/datasets/segmentation_dataset.py | |
+++ b/research/deeplab/datasets/segmentation_dataset.py | |
@@ -85,10 +85,24 @@ _PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor( | |
ignore_label=255, | |
) | |
+# These numbers (i.e., the 'train'/'val'/'eval' split sizes) have to be hard-coded. | |
+# You need to work them out for your own training/evaluation examples. | |
+# Is there a way to figure them out automatically? | |
+_ADE20K_INFORMATION = DatasetDescriptor( | |
+ splits_to_sizes = { | |
+ 'train': 20210, # num of samples in images/training | |
+ 'val': 2000, # num of samples in images/validation | |
+ 'eval': 2,  # NOTE(review): looks like a placeholder; set to the number of samples converted for 'eval' | |
+ }, | |
+ num_classes=150,  # NOTE(review): confirm against the label values stored in the annotation PNGs | |
+ ignore_label=255,  # NOTE(review): confirm that 255 is the value used for unlabeled pixels | |
+) | |
+ | |
_DATASETS_INFORMATION = { | |
'cityscapes': _CITYSCAPES_INFORMATION, | |
'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION, | |
+ 'ade20k': _ADE20K_INFORMATION, | |
} | |
# Default file pattern of TFRecord of TensorFlow Example. | |
diff --git a/research/deeplab/train_ade20k.sh b/research/deeplab/train_ade20k.sh | |
new file mode 100755 | |
index 0000000..2005409 | |
--- /dev/null | |
+++ b/research/deeplab/train_ade20k.sh | |
@@ -0,0 +1,38 @@ | |
+#!/bin/bash | |
+ | |
+# NOTE To use the ADE20K dataset, do | |
+# 1. set `--initialize_last_layer=False` so that we can have @num_classes==150 | |
+# 2. set proper values for `--min_resize_value` (e.g., 350) and | |
+#    `--max_resize_value` (e.g., 500) and let `--resize_factor` equal | |
+#    `--output_stride` | |
+# 3. set @exclude_list in utils/train_utils.py to only include _LOGITS_SCOPE_NAME | |
+# 4. modify the setting of _ADE20K_INFORMATION in datasets/segmentation_dataset.py | |
+# | |
+# Note also that, as stated in the doc, if using a fine-tuned pretrained model, | |
+# one can set `fine_tune_batch_norm=False` and use a smaller batch size (you | |
+# will need a batch size >= 16 if you are going to fine-tune the BN parameters) | |
+# NOTE: multi-value flags (--atrous_rates, --train_crop_size) are repeated once per value. | |
+ | |
+python train.py \ | |
+  --logtostderr \ | |
+  --training_number_of_steps=130000 \ | |
+  --train_split="train" \ | |
+  --model_variant="xception_65" \ | |
+  --atrous_rates=6 \ | |
+  --atrous_rates=12 \ | |
+  --atrous_rates=18 \ | |
+  --output_stride=16 \ | |
+  --decoder_output_stride=4 \ | |
+  --train_crop_size=513 \ | |
+  --train_crop_size=513 \ | |
+  --train_batch_size=4 \ | |
+  --min_resize_value=350 \ | |
+  --max_resize_value=500 \ | |
+  --resize_factor=16 \ | |
+  --fine_tune_batch_norm=False \ | |
+  --dataset="ade20k" \ | |
+  --initialize_last_layer=False \ | |
+  --tf_initial_checkpoint=/home/yubin/project/models/research/deeplab/deeplabv3_pascal_train_aug/model.ckpt \ | |
+  --train_logdir=/home/yubin/project/models/research/deeplab/datasets/exp/ade20k/train \ | |
+  --dataset_dir=/home/yubin/project/models/research/deeplab/datasets/ADE20K/tfrecord/ | |
+ | |
diff --git a/research/deeplab/utils/train_utils.py b/research/deeplab/utils/train_utils.py | |
index e2562bf..9bc9cbc 100644 | |
--- a/research/deeplab/utils/train_utils.py | |
+++ b/research/deeplab/utils/train_utils.py | |
@@ -99,7 +99,8 @@ def get_model_init_fn(train_logdir, | |
tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint) | |
# Variables that will not be restored. | |
- exclude_list = ['global_step'] | |
+ #exclude_list = ['global_step'] | |
+ exclude_list = ['logits'] # model._LOGITS_SCOPE_NAME | |
if not initialize_last_layer: | |
exclude_list.extend(last_layers) | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment