Commit 7b0782d6 authored by Yuxin Wu's avatar Yuxin Wu

Clearly define the argument of `build_graph`

parent 4cc00393
## DEPRECATED
Please use gym or other APIs.
......@@ -9,6 +9,7 @@ import tensorflow as tf
import six
from ..utils.argtools import memoized
from ..utils.develop import log_deprecated
from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import get_current_tower_context
from ..input_source import InputSource
......@@ -96,15 +97,17 @@ class ModelDescBase(object):
Build the whole symbolic graph.
Args:
args (list[tf.Tensor]): a list of tensors,
that match the list of :class:`InputDesc` defined by ``_get_inputs``.
args ([tf.Tensor]): a list of tensors,
that matches the list of :class:`InputDesc` defined by ``_get_inputs``.
"""
if len(args) == 1:
arg = args[0]
if isinstance(arg, InputSource):
inputs = arg.get_input_tensors() # remove in the future?
log_deprecated("build_graph(InputSource)", "Call with tensors in positional args instead.")
elif isinstance(arg, (list, tuple)):
inputs = arg
log_deprecated("build_graph([Tensor])", "Call with positional args instead.")
else:
inputs = [arg]
else:
......@@ -163,7 +166,7 @@ class ModelDesc(ModelDescBase):
raise NotImplementedError()
def _build_graph_get_cost(self, *inputs):
self.build_graph(inputs)
self.build_graph(*inputs)
return self.get_cost()
def _build_graph_get_grads(self, *inputs):
......
......@@ -89,7 +89,7 @@ class ModelExport(object):
"""
logger.info('[export] build model for %s' % checkpoint)
with TowerContext('', is_training=False):
self.model.build_graph(self.input)
self.model.build_graph(*self.input.get_input_tensors())
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# load values from latest checkpoint
......@@ -129,8 +129,6 @@ class ModelExport(object):
outputs=outputs_signature,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
# legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
self.sess, tags,
signature_def_map={signature_name: prediction_signature})
......
......@@ -86,7 +86,7 @@ class Trainer(object):
self._config = config
self.inputs_desc = config.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(
lambda *inputs: config.model.build_graph(inputs),
lambda *inputs: config.model.build_graph(*inputs),
self.inputs_desc)
self._main_tower_vs_name = ""
......
......@@ -123,7 +123,7 @@ class Trainer(object):
if self.model is not None:
def f(*inputs):
self.model.build_graph(inputs)
self.model.build_graph(*inputs)
"""
Only to mimic new trainer interafce on inference.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment