Commit 7b0782d6 authored by Yuxin Wu's avatar Yuxin Wu

Clearly define the argument of `build_graph`

parent 4cc00393
## DEPRECATED
Please use gym or other APIs.
...@@ -9,6 +9,7 @@ import tensorflow as tf ...@@ -9,6 +9,7 @@ import tensorflow as tf
import six import six
from ..utils.argtools import memoized from ..utils.argtools import memoized
from ..utils.develop import log_deprecated
from ..tfutils.gradproc import FilterNoneGrad from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..input_source import InputSource from ..input_source import InputSource
...@@ -96,15 +97,17 @@ class ModelDescBase(object): ...@@ -96,15 +97,17 @@ class ModelDescBase(object):
Build the whole symbolic graph. Build the whole symbolic graph.
Args: Args:
args (list[tf.Tensor]): a list of tensors, args ([tf.Tensor]): a list of tensors,
that match the list of :class:`InputDesc` defined by ``_get_inputs``. that matches the list of :class:`InputDesc` defined by ``_get_inputs``.
""" """
if len(args) == 1: if len(args) == 1:
arg = args[0] arg = args[0]
if isinstance(arg, InputSource): if isinstance(arg, InputSource):
inputs = arg.get_input_tensors() # remove in the future? inputs = arg.get_input_tensors() # remove in the future?
log_deprecated("build_graph(InputSource)", "Call with tensors in positional args instead.")
elif isinstance(arg, (list, tuple)): elif isinstance(arg, (list, tuple)):
inputs = arg inputs = arg
log_deprecated("build_graph([Tensor])", "Call with positional args instead.")
else: else:
inputs = [arg] inputs = [arg]
else: else:
...@@ -163,7 +166,7 @@ class ModelDesc(ModelDescBase): ...@@ -163,7 +166,7 @@ class ModelDesc(ModelDescBase):
raise NotImplementedError() raise NotImplementedError()
def _build_graph_get_cost(self, *inputs): def _build_graph_get_cost(self, *inputs):
self.build_graph(inputs) self.build_graph(*inputs)
return self.get_cost() return self.get_cost()
def _build_graph_get_grads(self, *inputs): def _build_graph_get_grads(self, *inputs):
......
...@@ -89,7 +89,7 @@ class ModelExport(object): ...@@ -89,7 +89,7 @@ class ModelExport(object):
""" """
logger.info('[export] build model for %s' % checkpoint) logger.info('[export] build model for %s' % checkpoint)
with TowerContext('', is_training=False): with TowerContext('', is_training=False):
self.model.build_graph(self.input) self.model.build_graph(*self.input.get_input_tensors())
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# load values from latest checkpoint # load values from latest checkpoint
...@@ -129,8 +129,6 @@ class ModelExport(object): ...@@ -129,8 +129,6 @@ class ModelExport(object):
outputs=outputs_signature, outputs=outputs_signature,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
# legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables( builder.add_meta_graph_and_variables(
self.sess, tags, self.sess, tags,
signature_def_map={signature_name: prediction_signature}) signature_def_map={signature_name: prediction_signature})
......
...@@ -86,7 +86,7 @@ class Trainer(object): ...@@ -86,7 +86,7 @@ class Trainer(object):
self._config = config self._config = config
self.inputs_desc = config.model.get_inputs_desc() self.inputs_desc = config.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper( self.tower_func = TowerFuncWrapper(
lambda *inputs: config.model.build_graph(inputs), lambda *inputs: config.model.build_graph(*inputs),
self.inputs_desc) self.inputs_desc)
self._main_tower_vs_name = "" self._main_tower_vs_name = ""
......
...@@ -123,7 +123,7 @@ class Trainer(object): ...@@ -123,7 +123,7 @@ class Trainer(object):
if self.model is not None: if self.model is not None:
def f(*inputs): def f(*inputs):
self.model.build_graph(inputs) self.model.build_graph(*inputs)
""" """
Only to mimic new trainer interafce on inference. Only to mimic new trainer interafce on inference.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment