Skip to content

Instantly share code, notes, and snippets.

@hanxiao
Created July 4, 2018 11:05
Show Gist options
  • Save hanxiao/39d0ae9ca3a9a3af0f5bd69bb0581d6e to your computer and use it in GitHub Desktop.
Save hanxiao/39d0ae9ca3a9a3af0f5bd69bb0581d6e to your computer and use it in GitHub Desktop.
Rejection sampling (class rebalancing) with the tf.data Dataset API
# Rejection-resample the input so class frequencies move from `self.init_dist`
# toward `self.uniform_dist`, then shuffle, batch, turn each batch into a
# tf.string via `self._make_batch`, and prefetch.
#
# NOTE(review): `tf.data.Dataset` below is carried over from the original
# snippet as a placeholder. `apply` is an instance method, so this must be the
# actual dataset object; calling it on the class raises TypeError.
# TODO: replace with the real dataset instance.
(tf.data.Dataset
 .apply(tf.contrib.data.rejection_resample(
     # rejection_resample expects class_func to return a *scalar* class id.
     # Passing Tout as a single dtype (not `[tf.int64]`) makes py_func return
     # one Tensor instead of a one-element list, and reshaping to [] gives it
     # a static scalar shape, which py_func does not provide on its own.
     class_func=lambda x: tf.reshape(
         tf.py_func(self.get_class_label, [x], tf.int64,
                    name='class_label_fn'),
         []),
     target_dist=tf.constant(self.uniform_dist, dtype=tf.float32),
     initial_dist=tf.constant(self.init_dist, dtype=tf.float32)))
 # rejection_resample yields (class_id, element) pairs; keep only the element.
 .map(lambda _, x: x)
 .shuffle(buffer_size=1000)
 .batch(self._args.batch_size)
 .map(lambda x: tf.py_func(self._make_batch, [x], tf.string,
                           name='train_mkbatch_fn'))
 # NOTE(review): after .batch() the prefetch buffer is measured in batches,
 # so this buffers batch_size * prefetch_factor *batches*; confirm intended.
 .prefetch(self._args.batch_size * self._args.prefetch_factor))
def get_class_label(self, sample):
    """Return the value stored under the ``'label'`` key of one JSON record.

    This is the Python callback wrapped by ``tf.py_func`` to serve as
    ``rejection_resample``'s ``class_func``.

    Args:
        sample: One JSON-encoded record. ``json.loads`` accepts ``str`` and,
            on Python 3.6+, ``bytes``. Presumably ``tf.py_func`` hands a
            ``tf.string`` scalar over as ``bytes``; TODO confirm.

    Returns:
        The record's ``'label'`` value. Presumably an integer class id in
        ``[0, num_classes)``; verify against the data.

    Raises:
        ValueError: if ``sample`` is not valid JSON (``json.JSONDecodeError``).
        KeyError: if the decoded record has no ``'label'`` key.
    """
    record = json.loads(sample)
    return record['label']
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment