Commit 1f498ed6 authored by Yuxin Wu's avatar Yuxin Wu

tf.TensorSpec == InputDesc

parent c667b1de
...@@ -7,10 +7,15 @@ import tensorflow as tf ...@@ -7,10 +7,15 @@ import tensorflow as tf
from ..models.regularize import regularize_cost_from_collection from ..models.regularize import regularize_cost_from_collection
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_tuple
from ..utils import logger from ..utils import logger
from ..utils.argtools import memoized_method from ..utils.argtools import memoized_method
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
if get_tf_version_tuple() >= (1, 7):
from tensorflow.python.framework.tensor_spec import TensorSpec
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase'] __all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
...@@ -71,7 +76,7 @@ class InputDesc( ...@@ -71,7 +76,7 @@ class InputDesc(
self._cached_placeholder[graph] = placeholder self._cached_placeholder[graph] = placeholder
@staticmethod @staticmethod
def from_placeholder(placeholder): def _from_placeholder(placeholder):
name = placeholder.op.name name = placeholder.op.name
if name.endswith('_1') or name.endswith('_2'): if name.endswith('_1') or name.endswith('_2'):
logger.error("Creating InputDesc from a placeholder named {}.".format(name)) logger.error("Creating InputDesc from a placeholder named {}.".format(name))
...@@ -83,6 +88,11 @@ class InputDesc( ...@@ -83,6 +88,11 @@ class InputDesc(
ret._register_cached_placeholder(placeholder) ret._register_cached_placeholder(placeholder)
return ret return ret
@staticmethod
def _from_tensor_spec(spec):
assert spec.name is not None, "TensorSpec should have a name!"
return InputDesc(spec.dtype, tuple(spec.shape.as_list()), spec.name)
class ModelDescBase(object): class ModelDescBase(object):
""" """
...@@ -106,9 +116,14 @@ class ModelDescBase(object): ...@@ -106,9 +116,14 @@ class ModelDescBase(object):
except NotImplementedError: except NotImplementedError:
with tf.Graph().as_default() as G: # create these placeholder in a temporary graph with tf.Graph().as_default() as G: # create these placeholder in a temporary graph
inputs = self.inputs() inputs = self.inputs()
for p in inputs: if isinstance(inputs[0], tf.Tensor):
assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!" for p in inputs:
return [InputDesc.from_placeholder(p) for p in inputs] assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
return [InputDesc._from_placeholder(p) for p in inputs]
else:
for p in inputs:
assert isinstance(p, TensorSpec), type(p)
return [InputDesc._from_tensor_spec(p) for p in inputs]
@property @property
def input_names(self): def input_names(self):
...@@ -123,16 +138,17 @@ class ModelDescBase(object): ...@@ -123,16 +138,17 @@ class ModelDescBase(object):
def inputs(self): def inputs(self):
""" """
__Create__ and returns a list of placeholders. Returns a list of :class:`tf.TensorSpec` or placeholders.
A subclass is expected to implement this method. A subclass is expected to implement this method.
The placeholders __have to__ be created inside this method. If returning placeholders,
Don't return placeholders created in other methods. the placeholders __have to__ be created inside this method.
Don't return placeholders created in other places.
Also, you should never call this method by yourself. Also, you should never call this method by yourself.
Returns: Returns:
a list of `tf.placeholder`, to be converted to :class:`InputDesc`. list[tf.placeholder] or list[tf.TensorSpec], to be converted to :class:`InputDesc`.
""" """
raise NotImplementedError() raise NotImplementedError()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment