Commit 1f498ed6 authored by Yuxin Wu's avatar Yuxin Wu

tf.TensorSpec == InputDesc

parent c667b1de
......@@ -7,10 +7,15 @@ import tensorflow as tf
from ..models.regularize import regularize_cost_from_collection
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_tuple
from ..utils import logger
from ..utils.argtools import memoized_method
from ..utils.develop import log_deprecated
if get_tf_version_tuple() >= (1, 7):
from tensorflow.python.framework.tensor_spec import TensorSpec
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
......@@ -71,7 +76,7 @@ class InputDesc(
self._cached_placeholder[graph] = placeholder
@staticmethod
def from_placeholder(placeholder):
def _from_placeholder(placeholder):
name = placeholder.op.name
if name.endswith('_1') or name.endswith('_2'):
logger.error("Creating InputDesc from a placeholder named {}.".format(name))
......@@ -83,6 +88,11 @@ class InputDesc(
ret._register_cached_placeholder(placeholder)
return ret
@staticmethod
def _from_tensor_spec(spec):
assert spec.name is not None, "TensorSpec should have a name!"
return InputDesc(spec.dtype, tuple(spec.shape.as_list()), spec.name)
class ModelDescBase(object):
"""
......@@ -106,9 +116,14 @@ class ModelDescBase(object):
except NotImplementedError:
with tf.Graph().as_default() as G: # create these placeholder in a temporary graph
inputs = self.inputs()
if isinstance(inputs[0], tf.Tensor):
for p in inputs:
assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
return [InputDesc.from_placeholder(p) for p in inputs]
return [InputDesc._from_placeholder(p) for p in inputs]
else:
for p in inputs:
assert isinstance(p, TensorSpec), type(p)
return [InputDesc._from_tensor_spec(p) for p in inputs]
@property
def input_names(self):
......@@ -123,16 +138,17 @@ class ModelDescBase(object):
def inputs(self):
"""
__Create__ and returns a list of placeholders.
Returns a list of :class:`tf.TensorSpec` or placeholders.
A subclass is expected to implement this method.
The placeholders __have to__ be created inside this method.
Don't return placeholders created in other methods.
If returning placeholders,
the placeholders __have to__ be created inside this method.
Don't return placeholders created in other places.
Also, you should never call this method by yourself.
Returns:
a list of `tf.placeholder`, to be converted to :class:`InputDesc`.
list[tf.placeholder] or list[tf.TensorSpec], to be converted to :class:`InputDesc`.
"""
raise NotImplementedError()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment