Commit a0324a28 authored by Yuxin Wu's avatar Yuxin Wu

use `build_graph` with positional args

parent fa0d4dc6
...@@ -140,7 +140,7 @@ class MultiGPUGANTrainer(TowerTrainer): ...@@ -140,7 +140,7 @@ class MultiGPUGANTrainer(TowerTrainer):
# build the graph # build the graph
def get_cost(*inputs): def get_cost(*inputs):
model.build_graph(inputs) model.build_graph(*inputs)
return [model.d_loss, model.g_loss] return [model.d_loss, model.g_loss]
tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc()) tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
......
...@@ -200,7 +200,7 @@ if __name__ == '__main__': ...@@ -200,7 +200,7 @@ if __name__ == '__main__':
input = PlaceholderInput() input = PlaceholderInput()
input.setup(input_desc) input.setup(input_desc)
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
model.build_graph(input) model.build_graph(*input.get_input_tensors())
tf.profiler.profile( tf.profiler.profile(
tf.get_default_graph(), tf.get_default_graph(),
......
...@@ -27,7 +27,7 @@ with tf.Graph().as_default() as G: ...@@ -27,7 +27,7 @@ with tf.Graph().as_default() as G:
with TowerContext('', is_training=False): with TowerContext('', is_training=False):
input = PlaceholderInput() input = PlaceholderInput()
input.setup(M.get_inputs_desc()) input.setup(M.get_inputs_desc())
M.build_graph(input) M.build_graph(*input.get_input_tensors())
else: else:
tf.train.import_meta_graph(args.meta) tf.train.import_meta_graph(args.meta)
......
...@@ -95,10 +95,13 @@ class ModelDescBase(object): ...@@ -95,10 +95,13 @@ class ModelDescBase(object):
def build_graph(self, *args): def build_graph(self, *args):
""" """
Build the whole symbolic graph. Build the whole symbolic graph.
This is supposed to be the "tower function" when used with :class:`TowerTrainer`.
By default it will call :meth:`_build_graph`
with a list of input tensors.
Args: Args:
args ([tf.Tensor]): a list of tensors, args ([tf.Tensor]): tensors that matches the list of
that matches the list of :class:`InputDesc` defined by ``_get_inputs``. :class:`InputDesc` defined by ``_get_inputs``.
""" """
if len(args) == 1: if len(args) == 1:
arg = args[0] arg = args[0]
...@@ -118,8 +121,10 @@ class ModelDescBase(object): ...@@ -118,8 +121,10 @@ class ModelDescBase(object):
"in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc())) "in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc()))
self._build_graph(inputs) self._build_graph(inputs)
@abstractmethod
def _build_graph(self, inputs): def _build_graph(self, inputs):
"""
This is an old interface which takes a list of tensors, instead of positional arguments.
"""
pass pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment