Last active
June 23, 2019 11:57
-
-
Save gaphex/8641a3cafe8870cdcbb7f6a201c16e40 to your computer and use it in GitHub Desktop.
input fn with generator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_input_fn(container):
    """Create an ``input_fn`` whose dataset is fed from ``container``.

    Args:
        container: Object exposing ``get()``; its current value is passed to
            ``build_feed_dict`` every time the dataset pulls an element.

    Returns:
        A zero-argument callable that builds a ``tf.data.Dataset`` via
        ``from_generator``. Every name in ``INPUT_NAMES`` is declared as an
        int32 feature of shape ``(None, None)``.
    """

    def _current_features():
        # Read the container at pull time, not at construction time.
        return build_feed_dict(container.get())

    def _feature_stream():
        # Endless stream: the generator never finishes on its own.
        while True:
            try:
                yield _current_features()
            except StopIteration:
                # A StopIteration surfacing during this step is answered by
                # producing the features once more from inside the handler.
                # NOTE(review): the handler repeats the try body verbatim;
                # the reason is not visible in this file -- confirm intent.
                yield _current_features()

    def input_fn():
        feature_types = dict.fromkeys(INPUT_NAMES, tf.int32)
        feature_shapes = dict.fromkeys(INPUT_NAMES, (None, None))
        return tf.data.Dataset.from_generator(
            _feature_stream,
            output_types=feature_types,
            output_shapes=feature_shapes,
        )

    return input_fn
class DataContainer:
    """Mutable holder for the texts currently stored via ``set``.

    ``get`` returns ``None`` until ``set`` has been called.
    """

    def __init__(self):
        # Nothing stored yet.
        self._texts = None

    def set(self, texts):
        """Store ``texts``, wrapping a single string in a one-element list.

        Args:
            texts: Either one string or an already-built collection of texts.
                Non-string values are stored as given.
        """
        # isinstance (rather than an exact type() comparison) so that
        # instances of str subclasses are also treated as a single text.
        if isinstance(texts, str):
            texts = [texts]
        self._texts = texts

    def get(self):
        """Return the value most recently stored by ``set`` (or ``None``)."""
        return self._texts
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment