Commit 71c879bc authored by Yuxin Wu

move build_or_reuse_placeholder to input_source

parent 4dadc6f0
docs/modules/graph_builder.rst:

 tensorpack.graph_builder package
 ================================
+These are some useful functions if you need to write your own trainers.
+Note that they may not be well maintained.
 .. automodule:: tensorpack.graph_builder
     :members:
     :undoc-members:
......
tensorpack/graph_builder/model_desc.py:

@@ -17,35 +17,6 @@ TensorSpec = backport_tensor_spec()

 __all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']

-def build_or_reuse_placeholder(tensor_spec):
-    """
-    Build a tf.placeholder from the metadata in the given tensor spec, or return an existing one.
-
-    Args:
-        tensor_spec (tf.TensorSpec):
-
-    Returns:
-        tf.Tensor:
-    """
-    g = tfv1.get_default_graph()
-    name = tensor_spec.name
-    try:
-        tensor = g.get_tensor_by_name(name + ':0')
-        assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name)
-        assert tensor_spec.is_compatible_with(tensor), \
-            "Tensor {} exists but is not compatible with the signature!".format(tensor)
-        if tensor.shape == tensor_spec.shape:
-            # Return the existing placeholder only on an exact shape match;
-            # a tower may want a placeholder with a different (e.g. less specific) shape.
-            return tensor
-    except KeyError:
-        pass
-    with tfv1.name_scope(None):   # clear any name scope it might get called in
-        ret = tfv1.placeholder(
-            tensor_spec.dtype, shape=tensor_spec.shape, name=tensor_spec.name)
-    return ret

 class InputDesc(
         namedtuple('InputDescTuple', ['type', 'shape', 'name'])):
     """

@@ -65,7 +36,7 @@ class InputDesc(
             shape (tuple):
             name (str):
         """
-        # TODO mark deprecated
+        log_deprecated("InputDesc", "Use tf.TensorSpec instead!", "2020-03-01")
         assert isinstance(type, tf.DType), type
        return tf.TensorSpec(shape=shape, dtype=type, name=name)
......
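The hunk above turns `InputDesc` into a deprecated thin wrapper that simply returns a `tf.TensorSpec`. A minimal sketch of the equivalent call (the shape and name below are illustrative, not part of this commit):

```python
import tensorflow as tf

# Deprecated after this commit (logs a warning; removal slated for 2020-03-01):
#     InputDesc(tf.float32, (None, 28, 28), 'image')
# Direct replacement, producing the same spec object:
spec = tf.TensorSpec(shape=(None, 28, 28), dtype=tf.float32, name='image')
```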
tensorpack/graph_builder/utils.py:

@@ -13,7 +13,7 @@ from ..tfutils.varreplace import custom_getter_scope
 from ..utils import logger
 from ..utils.argtools import call_only_once
-__all__ = ["LeastLoadedDeviceSetter"]
+__all__ = ["LeastLoadedDeviceSetter", "allreduce_grads", "aggregate_grads"]
 """
......
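The newly exported names above are the "useful functions if you need to write your own trainers" referred to in the docs change. A sketch of how `LeastLoadedDeviceSetter` is typically used when building replicated towers by hand; the device strings and the tower-building step are illustrative assumptions, not part of this commit:

```python
import tensorflow.compat.v1 as tfv1
from tensorpack.graph_builder.utils import LeastLoadedDeviceSetter

devices = ['/gpu:0', '/gpu:1']
for i, dev in enumerate(devices):
    # Ops run on this tower's GPU; each variable is placed on whichever
    # device currently holds the fewest bytes of variables.
    setter = LeastLoadedDeviceSetter(worker_device=dev, ps_devices=devices)
    with tfv1.device(setter), \
            tfv1.variable_scope(tfv1.get_variable_scope(), reuse=(i > 0)):
        pass  # build this tower's forward/backward graph here
```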
tensorpack/input_source/input_source.py:

@@ -18,8 +18,7 @@ from ..tfutils.summary import add_moving_summary
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.concurrency import ShareSessionThread
-from .input_source_base import InputSource
-from ..graph_builder.model_desc import build_or_reuse_placeholder
+from .input_source_base import InputSource, build_or_reuse_placeholder

 try:
     from tensorflow.python.ops.data_flow_ops import StagingArea
......
tensorpack/input_source/input_source_base.py:

@@ -12,11 +12,40 @@ from ..callbacks.base import CallbackFactory
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from ..utils.argtools import call_only_once, memoized_method
-from ..graph_builder.model_desc import build_or_reuse_placeholder
+from ..compat import tfv1

 __all__ = ['InputSource', 'remap_input_source']

+def build_or_reuse_placeholder(tensor_spec):
+    """
+    Build a tf.placeholder from the metadata in the given tensor spec, or return an existing one.
+
+    Args:
+        tensor_spec (tf.TensorSpec):
+
+    Returns:
+        tf.Tensor:
+    """
+    g = tfv1.get_default_graph()
+    name = tensor_spec.name
+    try:
+        tensor = g.get_tensor_by_name(name + ':0')
+        assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name)
+        assert tensor_spec.is_compatible_with(tensor), \
+            "Tensor {} exists but is not compatible with the signature!".format(tensor)
+        if tensor.shape == tensor_spec.shape:
+            # Return the existing placeholder only on an exact shape match;
+            # a tower may want a placeholder with a different (e.g. less specific) shape.
+            return tensor
+    except KeyError:
+        pass
+    with tfv1.name_scope(None):   # clear any name scope it might get called in
+        ret = tfv1.placeholder(
+            tensor_spec.dtype, shape=tensor_spec.shape, name=tensor_spec.name)
+    return ret

 def get_tensors_inputs(placeholders, tensors, names):
     """
     Args:
......
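For context, a minimal sketch of what the relocated helper does: the first call creates the placeholder, the second finds it by name and returns it instead of building a duplicate. The spec name and shape are illustrative, and the import path assumes this commit's module layout:

```python
import tensorflow as tf
from tensorflow.compat import v1 as tfv1
from tensorpack.input_source.input_source_base import build_or_reuse_placeholder

spec = tf.TensorSpec(shape=(None, 28, 28), dtype=tf.float32, name='image')

with tfv1.Graph().as_default():
    p1 = build_or_reuse_placeholder(spec)  # creates placeholder 'image:0'
    p2 = build_or_reuse_placeholder(spec)  # finds 'image:0' and reuses it
    assert p1 is p2  # the same graph tensor is returned, not a duplicate
```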