Commit 7bb9ee1c authored by Yuxin Wu's avatar Yuxin Wu

fix build

parent cef8ae29
......@@ -68,10 +68,11 @@ class GANTrainer(TowerTrainer):
def __init__(self, input, model):
super(GANTrainer, self).__init__()
assert isinstance(model, GANModelDesc), model
cbs = input.setup(model.get_inputs_desc())
inputs_desc = model.get_inputs_desc()
cbs = input.setup(inputs_desc)
tower_func = TowerFuncWrapper(
model.build_graph, model.get_inputs_desc())
model.build_graph, inputs_desc)
with TowerContext('', is_training=True):
tower_func(*input.get_input_tensors())
opt = model.get_optimizer()
......
......@@ -94,7 +94,7 @@ class ModelDescBase(object):
args (list[tf.Tensor]): a list of tensors,
that match the list of :class:`InputDesc` defined by ``_get_inputs``.
"""
if len(args) == 0:
if len(args) == 1:
arg = args[0]
if isinstance(arg, InputSource):
inputs = arg.get_input_tensors() # remove in the future?
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment