Skip to content

Instantly share code, notes, and snippets.

@zakirangwala
Created December 1, 2020 06:19
Show Gist options
  • Save zakirangwala/fce1a1e76991f6d677e8d025da51ac6a to your computer and use it in GitHub Desktop.
Save zakirangwala/fce1a1e76991f6d677e8d025da51ac6a to your computer and use it in GitHub Desktop.
Pipeline Configuration - Mask Detection
# Configuration
def config(num_classes=2, batch_size=4):
    """Rewrite the custom model's ``pipeline.config`` in place for fine-tuning.

    Reads the text-format pipeline file located at
    ``MODEL_PATH/CUSTOM_MODEL_NAME/pipeline.config``, overrides the SSD class
    count, the training batch size, the fine-tune checkpoint, and the label
    map / TFRecord paths of the train and first eval input readers, then
    writes the modified message back to the same file.

    Args:
        num_classes: Value assigned to ``model.ssd.num_classes``.
        batch_size: Value assigned to ``train_config.batch_size``.

    Returns:
        None. The only effect is the rewritten file on disk.
    """
    config_path = MODEL_PATH + '/' + CUSTOM_MODEL_NAME + '/pipeline.config'

    # Parse the existing pipeline file into a protobuf message so individual
    # fields can be overridden without touching the rest of the config.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(config_path, "r") as f:
        text_format.Merge(f.read(), pipeline_config)

    # Train and eval readers share the same label map.
    label_map_path = ANNOTATION_PATH + '/label_map.pbtxt'

    # Model / training overrides.
    pipeline_config.model.ssd.num_classes = num_classes
    train_config = pipeline_config.train_config
    train_config.batch_size = batch_size
    train_config.fine_tune_checkpoint = (
        PRETRAINED_MODEL_PATH
        + '/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8/checkpoint/ckpt-0'
    )
    train_config.fine_tune_checkpoint_type = "detection"

    # Training input: slice-assign so any paths already listed are replaced
    # rather than appended to.
    train_reader = pipeline_config.train_input_reader
    train_reader.label_map_path = label_map_path
    train_reader.tf_record_input_reader.input_path[:] = [
        ANNOTATION_PATH + '/train.record']

    # Evaluation input: only the first eval reader is updated.
    eval_reader = pipeline_config.eval_input_reader[0]
    eval_reader.label_map_path = label_map_path
    eval_reader.tf_record_input_reader.input_path[:] = [
        ANNOTATION_PATH + '/test.record']

    # Serialize back to text format and overwrite the original file.
    config_text = text_format.MessageToString(pipeline_config)
    with tf.io.gfile.GFile(config_path, "wb") as f:
        f.write(config_text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment