@DEKHTIARJonathan
Created January 11, 2018 14:47
Get Placeholders in Graph
import tensorflow as tf
from tensorflow.python.framework import ops


def get_placeholders(graph):
    """Get placeholders of a graph.

    For example:

    ```python
    a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
    b = tf.placeholder(dtype=tf.int32, shape=[3, 2], name='b')

    # would give [<tf.Tensor 'a:0' shape=(2, 2) dtype=float32>,
    #             <tf.Tensor 'b:0' shape=(3, 2) dtype=int32>]
    tf.contrib.framework.get_placeholders(tf.get_default_graph())
    ```

    Args:
      graph: A tf.Graph.

    Returns:
      A list containing all placeholders of the given graph.

    Raises:
      TypeError: If `graph` is not a tensorflow graph.
    """
    if not isinstance(graph, ops.Graph):
        raise TypeError("Input graph needs to be a Graph: %s" % graph)

    # For each placeholder() call, there is a corresponding
    # operation of type 'Placeholder' registered to the graph.
    # The return value (a Tensor) of placeholder() is in fact
    # the first output of this operation.
    operations = graph.get_operations()
    result = [i.outputs[0] for i in operations if i.type == "Placeholder"]
    return result


# List the placeholders of the default graph (TensorFlow 1.x API).
get_placeholders(tf.get_default_graph())
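
The filtering step in `get_placeholders` can be tried without TensorFlow installed. The following is a minimal sketch, assuming stand-in objects that mimic only the `type` and `outputs` attributes of a graph operation. The helper name `first_outputs_of_type` and the stand-in data are hypothetical and not part of the gist or of TensorFlow.

```python
from types import SimpleNamespace


def first_outputs_of_type(operations, op_type="Placeholder"):
    """Return the first output of every operation whose type equals op_type."""
    return [op.outputs[0] for op in operations if op.type == op_type]


# Hypothetical stand-ins for graph operations; only `type` and `outputs`
# are mimicked, and the outputs are plain strings rather than tensors.
fake_operations = [
    SimpleNamespace(type="Placeholder", outputs=["a:0"]),
    SimpleNamespace(type="Const", outputs=["c:0"]),
    SimpleNamespace(type="Placeholder", outputs=["b:0"]),
]

placeholder_outputs = first_outputs_of_type(fake_operations)
print(placeholder_outputs)  # ['a:0', 'b:0']
```

Taking the operation type as a parameter is a design choice of this sketch. It lets the same comprehension collect outputs of other operation types, while the default keeps the behaviour of the gist's filter.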