Commit 7bb9ee1c authored by Yuxin Wu's avatar Yuxin Wu

fix build

parent cef8ae29
...@@ -68,10 +68,11 @@ class GANTrainer(TowerTrainer): ...@@ -68,10 +68,11 @@ class GANTrainer(TowerTrainer):
def __init__(self, input, model): def __init__(self, input, model):
super(GANTrainer, self).__init__() super(GANTrainer, self).__init__()
assert isinstance(model, GANModelDesc), model assert isinstance(model, GANModelDesc), model
cbs = input.setup(model.get_inputs_desc()) inputs_desc = model.get_inputs_desc()
cbs = input.setup(inputs_desc)
tower_func = TowerFuncWrapper( tower_func = TowerFuncWrapper(
model.build_graph, model.get_inputs_desc()) model.build_graph, inputs_desc)
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
tower_func(*input.get_input_tensors()) tower_func(*input.get_input_tensors())
opt = model.get_optimizer() opt = model.get_optimizer()
......
...@@ -94,7 +94,7 @@ class ModelDescBase(object): ...@@ -94,7 +94,7 @@ class ModelDescBase(object):
args (list[tf.Tensor]): a list of tensors, args (list[tf.Tensor]): a list of tensors,
that match the list of :class:`InputDesc` defined by ``_get_inputs``. that match the list of :class:`InputDesc` defined by ``_get_inputs``.
""" """
if len(args) == 0: if len(args) == 1:
arg = args[0] arg = args[0]
if isinstance(arg, InputSource): if isinstance(arg, InputSource):
inputs = arg.get_input_tensors() # remove in the future? inputs = arg.get_input_tensors() # remove in the future?
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment