Commit 787be08e authored by Yuxin Wu's avatar Yuxin Wu

Deprecate warning for old modeldesc interface.

parent 3700a803
......@@ -9,7 +9,6 @@ from ..utils import logger
from ..utils.argtools import memoized
from ..utils.develop import log_deprecated
from ..tfutils.tower import get_current_tower_context
from ..input_source import InputSource
from ..models.regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
......@@ -95,10 +94,16 @@ class ModelDescBase(object):
def get_inputs_desc(self):
"""
Returns:
a list of :class:`InputDesc`.
A list of :class:`InputDesc`, which describes the inputs of this model.
The result is cached for each instance of :class:`ModelDescBase`.
"""
try:
return self._get_inputs()
ret = self._get_inputs()
log_deprecated(
"ModelDescBase._get_inputs() interface",
"Use inputs() instead!",
"2019-03-30")
return ret
except NotImplementedError:
with tf.Graph().as_default() as G: # create these placeholder in a temporary graph
inputs = self.inputs()
......@@ -106,6 +111,14 @@ class ModelDescBase(object):
assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
return [InputDesc.from_placeholder(p) for p in inputs]
@property
def input_names(self):
"""
Returns:
[str]: the names of all the inputs.
"""
return [k.name for k in self.get_inputs_desc()]
def _get_inputs(self):
raise NotImplementedError()
......@@ -116,7 +129,8 @@ class ModelDescBase(object):
The placeholders __have to__ be created inside this method.
Don't return placeholders created in other methods.
Also, you should not call this method by yourself.
Also, you should never call this method by yourself.
Returns:
a list of `tf.placeholder`, to be converted to :class:`InputDesc`.
......@@ -128,7 +142,7 @@ class ModelDescBase(object):
Build the whole symbolic graph.
This is supposed to be part of the "tower function" when used with :class:`TowerTrainer`.
A subclass is expected to overwrite this method.
A subclass is expected to implement this method.
Args:
args ([tf.Tensor]): tensors that matches the list of inputs defined by ``inputs()``.
......@@ -138,24 +152,14 @@ class ModelDescBase(object):
may require it to return necessary information to build the trainer.
For example, `SingleCostTrainer` expect this method to return the cost tensor.
"""
if len(args) == 1:
arg = args[0]
if isinstance(arg, InputSource):
inputs = arg.get_input_tensors() # remove in the future?
log_deprecated("build_graph(InputSource)",
"Call with tensors in positional args instead.", "2018-03-31")
elif isinstance(arg, (list, tuple)):
inputs = arg
log_deprecated("build_graph([Tensor])", "Call with positional args instead.", "2018-03-31")
else:
inputs = [arg]
else:
inputs = args
assert len(inputs) == len(self.get_inputs_desc()), \
assert len(args) == len(self.get_inputs_desc()), \
"Number of inputs passed to the graph != number of inputs defined " \
"in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc()))
return self._build_graph(inputs)
"in ModelDesc! ({} != {})".format(len(args), len(self.get_inputs_desc()))
log_deprecated(
"ModelDescBase._build_graph() interface",
"Use build_graph() instead!",
"2019-03-30")
return self._build_graph(args)
def _build_graph(self, inputs):
"""
......@@ -187,6 +191,10 @@ class ModelDesc(ModelDescBase):
and applies the collection
``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically.
"""
log_deprecated(
"get_cost() and self.cost",
"Return the cost tensor directly in build_graph() instead!",
"2019-03-30")
cost = self._get_cost()
reg_cost = regularize_cost_from_collection()
if reg_cost.op.type != 'Const':
......@@ -211,7 +219,12 @@ class ModelDesc(ModelDescBase):
a :class:`tf.train.Optimizer` instance.
"""
try:
return self._get_optimizer()
ret = self._get_optimizer()
log_deprecated(
"ModelDescBase._get_optimizer() interface",
"Use optimizer() instead!",
"2019-03-30")
return ret
except NotImplementedError:
pass
return self.optimizer()
......
......@@ -263,9 +263,9 @@ class TowerFuncWrapper(object):
They are used to figure out the names for the input tensors.
"""
assert callable(tower_fn), tower_fn
inputs_desc_names = [k.name for k in inputs_desc]
assert len(set(inputs_desc_names)) == len(inputs_desc_names), \
"Duplicated names in inputs_desc! " + str(inputs_desc_names)
self._inputs_desc_names = [k.name for k in inputs_desc]
assert len(set(self._inputs_desc_names)) == len(self._inputs_desc_names), \
"Duplicated names in inputs_desc! " + str(self._inputs_desc_names)
self._tower_fn = tower_fn
self._inputs_desc = inputs_desc
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment