Commit a0324a28 authored by Yuxin Wu's avatar Yuxin Wu

use `build_graph` with positional args

parent fa0d4dc6
......@@ -140,7 +140,7 @@ class MultiGPUGANTrainer(TowerTrainer):
# build the graph
def get_cost(*inputs):
model.build_graph(inputs)
model.build_graph(*inputs)
return [model.d_loss, model.g_loss]
tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
......
......@@ -200,7 +200,7 @@ if __name__ == '__main__':
input = PlaceholderInput()
input.setup(input_desc)
with TowerContext('', is_training=True):
model.build_graph(input)
model.build_graph(*input.get_input_tensors())
tf.profiler.profile(
tf.get_default_graph(),
......
......@@ -27,7 +27,7 @@ with tf.Graph().as_default() as G:
with TowerContext('', is_training=False):
input = PlaceholderInput()
input.setup(M.get_inputs_desc())
M.build_graph(input)
M.build_graph(*input.get_input_tensors())
else:
tf.train.import_meta_graph(args.meta)
......
......@@ -95,10 +95,13 @@ class ModelDescBase(object):
def build_graph(self, *args):
"""
Build the whole symbolic graph.
This is supposed to be the "tower function" when used with :class:`TowerTrainer`.
By default it will call :meth:`_build_graph`
with a list of input tensors.
Args:
args ([tf.Tensor]): a list of tensors,
that matches the list of :class:`InputDesc` defined by ``_get_inputs``.
args ([tf.Tensor]): tensors that matches the list of
:class:`InputDesc` defined by ``_get_inputs``.
"""
if len(args) == 1:
arg = args[0]
......@@ -118,8 +121,10 @@ class ModelDescBase(object):
"in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc()))
self._build_graph(inputs)
@abstractmethod
def _build_graph(self, inputs):
"""
This is an old interface which takes a list of tensors, instead of positional arguments.
"""
pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment