Skip to content

Instantly share code, notes, and snippets.

@nagachika
Created January 17, 2018 04:39
Show Gist options
  • Save nagachika/c35b7f65d2d5fa8b95ea527e9dd1cd91 to your computer and use it in GitHub Desktop.
Save nagachika/c35b7f65d2d5fa8b95ea527e9dd1cd91 to your computer and use it in GitHub Desktop.
Pitfall of tf.data.Dataset + tf.contrib.lookup ops
# Build a pipeline that reads TFRecord files with tf.data.
# This version does NOT work (NG).
labels = ["a", "b", "c"]


def input_fn():
    """Build the input pipeline with the label lookup INSIDE ``Dataset.map``.

    This is deliberately kept as the failing variant of the example: the
    string-to-index lookup table is constructed inside the function that is
    passed to ``Dataset.map``.

    NOTE(review): presumably this fails because the lookup table is a
    stateful resource that must be initialized, and it cannot be used from
    within the map function of a one-shot iterator -- confirm against the
    TensorFlow documentation for the version in use.

    Returns:
        The ``(features, label_ids)`` pair produced by the one-shot
        iterator's ``get_next()``.
    """

    def _parse_example(example):
        # Parse one serialized example and split the "label" entry off.
        features = tf.parse_single_example(example, features=tf_features_dict)
        label_str = features.pop("label")
        # Pitfall: the table is created here, inside the mapped function.
        label_table = tf.contrib.lookup.index_table_from_tensor(labels)
        return features, label_table.lookup(label_str)

    dataset = (
        tf.data.TFRecordDataset(filename)
        .map(_parse_example)
        .shuffle(batch_size * 10)
        .batch(batch_size)
    )
    return dataset.make_one_shot_iterator().get_next()
# This version works (OK).
labels = ["a", "b", "c"]


def input_fn():
    """Build the input pipeline with the label lookup OUTSIDE ``Dataset.map``.

    The dataset only parses the serialized examples; the string label is
    converted to its index in ``labels`` after ``get_next()``, so the lookup
    table is created outside the function passed to ``Dataset.map``.

    Returns:
        A ``(features, label_ids)`` pair: the parsed feature dict with its
        "label" entry removed, and the result of looking that label up in
        the table built from ``labels``.
    """

    def _lookup_label(label_str):
        # Map label strings to their positions in ``labels``.
        label_table = tf.contrib.lookup.index_table_from_tensor(labels)
        return label_table.lookup(label_str)

    def _parse_example(example):
        # Parsing only: no lookup table is touched inside the map function.
        return tf.parse_single_example(example, features=tf_features_dict)

    dataset = (
        tf.data.TFRecordDataset(filename)
        .map(_parse_example)
        .shuffle(batch_size * 10)
        .batch(batch_size)
    )
    features = dataset.make_one_shot_iterator().get_next()
    # Split the label off the batched features before converting it.
    label_str = features.pop("label")
    return features, _lookup_label(label_str)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment