Last active
December 15, 2020 12:03
-
-
Save alexwal/9fca4efb936265d62e389fba5bacd4b3 to your computer and use it in GitHub Desktop.
Example of how to handle errors in a tf.data.Dataset input pipeline
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
def create_bad_dataset(create_batches=True):
    """Build a dataset whose map step fails on one of its inputs.

    The source values include a 0., so taking the reciprocal and passing it
    through `tf.check_numerics` raises an InvalidArgumentError for that
    value when the dataset is evaluated.

    Args:
        create_batches: If True (default), map and batch through
            `tf.contrib.data.map_and_batch` with a batch size of 2, to
            demonstrate that error handling also works with that
            transformation. If False, use a plain `Dataset.map` and leave
            the elements unbatched.

    Returns:
        The resulting `tf.data.Dataset`.
    """
    def checked_reciprocal(value):
        # Computing `tf.check_numerics(1. / 0.)` is what raises the
        # InvalidArgumentError.
        return tf.check_numerics(1. / value, 'error')

    source = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4., 8., 16.])
    if not create_batches:
        return source.map(checked_reciprocal)

    map_then_batch = tf.contrib.data.map_and_batch(
        map_func=checked_reciprocal, batch_size=2)
    return source.apply(map_then_batch)
def create_bad_dataset_with_filter(create_batches=True):
    """Build the same reciprocal dataset, but filter out zeros first.

    Because elements equal to 0. are removed before the reciprocal is taken,
    `1. / 0.` is never computed and this pipeline should never raise.

    Args:
        create_batches: If True (default), map and batch through
            `tf.contrib.data.map_and_batch` with a batch size of 4, keeping
            a final partial batch (`drop_remainder=False`). If False, use a
            plain `Dataset.map` and leave the elements unbatched.

    Returns:
        The resulting `tf.data.Dataset`.
    """
    def checked_reciprocal(value):
        return tf.check_numerics(1. / value, 'error')

    def is_nonzero(value):
        return tf.not_equal(value, 0.)

    # Order matters: prefetch, then drop the zeros, and only then map.
    pipeline = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4., 8., 16.])
    pipeline = pipeline.prefetch(3).filter(is_nonzero)

    if not create_batches:
        return pipeline.map(checked_reciprocal)

    map_then_batch = tf.contrib.data.map_and_batch(
        map_func=checked_reciprocal, batch_size=4, drop_remainder=False)
    return pipeline.apply(map_then_batch)
def test_without_error_handling():
    """Iterate the bad dataset while handling only end-of-data.

    Each fetched value is printed. `tf.errors.OutOfRangeError` (dataset
    exhausted) ends the loop; any other error, including the
    InvalidArgumentError produced by the bad element, is deliberately left
    uncaught and propagates to the caller.
    """
    next_element = create_bad_dataset().make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        exhausted = False
        while not exhausted:
            try:
                print(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                print('break from loop')
                exhausted = True
def test_catch_error_in_run_loop():
    """Iterate the bad dataset, catching the numeric error in the run loop.

    Each fetched value is printed. An InvalidArgumentError is reported with
    a message and iteration simply continues with the next `sess.run` call;
    `tf.errors.OutOfRangeError` (dataset exhausted) ends the loop.
    """
    next_element = create_bad_dataset().make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        exhausted = False
        while not exhausted:
            try:
                print(sess.run(next_element))
            except tf.errors.InvalidArgumentError:
                # Report the failure and keep pulling from the iterator.
                print('Error: InvalidArgumentError')
            except tf.errors.OutOfRangeError:
                print('break from loop')
                exhausted = True
def test_ignore_errors():
    """Iterate the bad dataset with failing elements silently dropped.

    `tf.contrib.data.ignore_errors()` is applied on top of the bad dataset so
    that the element which raises is dropped instead of surfacing an error.
    Each remaining value is printed until `tf.errors.OutOfRangeError`
    signals that the dataset is exhausted.
    """
    # NOTE(review): create_bad_dataset() batches by default, so the dropped
    # "element" here is presumably the whole batch containing the zero rather
    # than the single bad value -- confirm against actual output.
    drop_failures = tf.contrib.data.ignore_errors()
    tolerant = create_bad_dataset().apply(drop_failures)
    next_element = tolerant.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        exhausted = False
        while not exhausted:
            try:
                print(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                print('break from loop')
                exhausted = True
def run():
    """Run every demonstration in order, printing a banner before each."""
    demos = (
        ('\n--> Testing by catching errors in run loop...',
         test_catch_error_in_run_loop),
        ('\n--> Testing by catching errors with tf.contrib.data.ignore_errors()...',
         test_ignore_errors),
        # NOTE: this last demo is active, not commented out. It handles only
        # OutOfRangeError, so the InvalidArgumentError from the bad dataset
        # propagates out of run() uncaught. Remove this entry to skip it.
        ('\n--> Testing without error handling...',
         test_without_error_handling),
    )
    for banner, demo in demos:
        print(banner)
        demo()
# Script entry point: run all demonstrations only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment