Commit cef8ae29 authored by Yuxin Wu

Fix main script in all GAN examples

parent 5bd3c395
@@ -105,18 +105,6 @@ def get_data():
     return BatchData(ds, BATCH)
 
 
-def get_config():
-    logger.auto_set_dir()
-    dataset = get_data()
-    return TrainConfig(
-        dataflow=dataset,
-        callbacks=[ModelSaver()],
-        model=Model(),
-        steps_per_epoch=500,
-        max_epoch=100,
-    )
-
-
 def sample(model_path):
     pred = PredictConfig(
         session_init=get_model_loader(model_path),
@@ -145,7 +133,12 @@ if __name__ == '__main__':
     if args.sample:
         sample(args.load)
     else:
-        config = get_config()
+        logger.auto_set_dir()
+        config = TrainConfig(
+            callbacks=[ModelSaver()],
+            steps_per_epoch=500,
+            max_epoch=100,
+        )
         if args.load:
             config.session_init = SaverRestore(args.load)
-        GANTrainer(config).train()
+        GANTrainer(QueueInput(get_data()), Model()).train_with_config(config)
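This hunk shows the migration pattern the commit applies to every GAN example: the model and the dataflow move out of TrainConfig and into the trainer itself. A minimal before/after sketch, assuming the tensorpack imports already used in these examples:

    # old flow (removed): TrainConfig owns the model and the dataflow
    config = TrainConfig(dataflow=get_data(), model=Model(),
                         callbacks=[ModelSaver()],
                         steps_per_epoch=500, max_epoch=100)
    GANTrainer(config).train()

    # new flow (added): the trainer owns the InputSource and the ModelDesc;
    # TrainConfig keeps only the callbacks and the schedule
    config = TrainConfig(callbacks=[ModelSaver()],
                         steps_per_epoch=500, max_epoch=100)
    GANTrainer(QueueInput(get_data()), Model()).train_with_config(config)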
@@ -219,8 +219,6 @@ if __name__ == '__main__':
     data = PrintData(data)
 
     config = TrainConfig(
-        model=Model(),
-        dataflow=data,
         callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter(
@@ -229,7 +227,8 @@ if __name__ == '__main__':
             PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3),
         ],
         max_epoch=195,
         steps_per_epoch=data.size(),
+        session_init=SaverRestore(args.load) if args.load else None
     )
 
-    GANTrainer(config).train()
+    GANTrainer(QueueInput(data), Model()).train_with_config(config)
@@ -218,8 +218,6 @@ if __name__ == '__main__':
     data = get_celebA_data(args.data, args.style_A, args.style_B)
 
     config = TrainConfig(
-        model=Model(),
-        dataflow=data,
         callbacks=[ModelSaver()],
         steps_per_epoch=300,
         max_epoch=250,
@@ -227,4 +225,5 @@ if __name__ == '__main__':
     )
 
     # train 1 D after 2 G
-    SeparateGANTrainer(config, d_period=3).train()
+    SeparateGANTrainer(
+        QueueInput(data), Model(), d_period=3).train_with_config(config)
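Per the comment above, `d_period=3` yields one discriminator update for every two generator updates. A hypothetical sketch of that schedule (illustration only, not the trainer's internal code):

    # with d_period=3, one of every three steps runs the D optimizer,
    # the remaining two run the G optimizer: G G D G G D ...
    d_period = 3
    for step in range(6):
        op = 'd_min' if (step + 1) % d_period == 0 else 'g_min'
        print(step, op)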
@@ -70,9 +70,10 @@ class GANTrainer(TowerTrainer):
         assert isinstance(model, GANModelDesc), model
 
         cbs = input.setup(model.get_inputs_desc())
-        tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
+        tower_func = TowerFuncWrapper(
+            model.build_graph, model.get_inputs_desc())
         with TowerContext('', is_training=True):
-            tower_func(input)
+            tower_func(*input.get_input_tensors())
 
         opt = model.get_optimizer()
         # by default, run one d_min after one g_min
@@ -103,7 +104,7 @@ class SeparateGANTrainer(TowerTrainer):
         cbs = input.setup(model.get_inputs_desc())
         tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
         with TowerContext('', is_training=True):
-            tower_func(input)
+            tower_func(*input.get_input_tensors())
         opt = model.get_optimizer()
 
         with tf.name_scope('optimize'):
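In both trainers the tower function is now called with the unpacked input tensors rather than with the InputSource object, matching the stricter `build_graph(*args)` signature introduced in `ModelDescBase` further down. A minimal sketch, assuming `input` is an InputSource that has already been set up:

    # get_input_tensors() returns one tf.Tensor per InputDesc;
    # they are passed positionally into the model's build_graph
    tensors = input.get_input_tensors()
    tower_func(*tensors)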
@@ -169,21 +169,6 @@ def get_data():
     return ds
 
 
-def get_config():
-    logger.auto_set_dir()
-    dataset = get_data()
-    return TrainConfig(
-        dataflow=dataset,
-        callbacks=[
-            PeriodicTrigger(ModelSaver(), every_k_epochs=3),
-            ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
-        ],
-        model=Model(),
-        steps_per_epoch=dataset.size(),
-        max_epoch=300,
-    )
-
-
 def sample(datadir, model_path):
     pred = PredictConfig(
         session_init=get_model_loader(model_path),
@@ -219,9 +204,21 @@ if __name__ == '__main__':
     BATCH = args.batch
     if args.sample:
         assert args.load
        sample(args.data, args.load)
     else:
-        config = get_config()
+        logger.auto_set_dir()
+
+        data = QueueInput(get_data())
+
+        config = TrainConfig(
+            callbacks=[
+                PeriodicTrigger(ModelSaver(), every_k_epochs=3),
+                ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
+            ],
+            steps_per_epoch=data.size(),
+            max_epoch=300,
+        )
+
         if args.load:
             config.session_init = SaverRestore(args.load)
-        GANTrainer(config).train()
+        GANTrainer(data, Model()).train_with_config(config)
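Here `steps_per_epoch=data.size()` relies on the InputSource reporting the size of the dataflow it wraps (the removed `get_config()` used `dataset.size()` for the same purpose), so one epoch still means one pass over `get_data()`. A sketch of the relationship, as a hypothetical check for illustration:

    ds = get_data()
    data = QueueInput(ds)
    # QueueInput delegates size() to the wrapped dataflow
    assert data.size() == ds.size()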
@@ -96,11 +96,11 @@ if __name__ == '__main__':
     assert args.data
     logger.auto_set_dir()
     config = TrainConfig(
-        model=Model(),
-        dataflow=DCGAN.get_data(args.data),
         callbacks=[ModelSaver()],
         steps_per_epoch=300,
         max_epoch=200,
+        session_init=SaverRestore(args.load) if args.load else None
     )
-
-    SeparateGANTrainer(config, g_period=6).train()
+    SeparateGANTrainer(
+        QueueInput(DCGAN.get_data(args.data)),
+        Model(), g_period=6).train_with_config(config)
@@ -190,17 +190,6 @@ def get_data():
     return ds
 
 
-def get_config():
-    logger.auto_set_dir('d')
-    return TrainConfig(
-        dataflow=get_data(),
-        callbacks=[ModelSaver(keep_freq=0.1)],
-        model=Model(),
-        steps_per_epoch=500,
-        max_epoch=100,
-    )
-
-
 def sample(model_path):
     pred = OfflinePredictor(PredictConfig(
         session_init=get_model_loader(model_path),
@@ -255,7 +244,12 @@ if __name__ == '__main__':
         BATCH = 100
         sample(args.load)
     else:
-        config = get_config()
-        if args.load:
-            config.session_init = SaverRestore(args.load)
-        GANTrainer(config).train()
+        logger.auto_set_dir()
+        cfg = TrainConfig(
+            callbacks=[ModelSaver(keep_freq=0.1)],
+            steps_per_epoch=500,
+            max_epoch=100,
+            session_init=SaverRestore(args.load) if args.load else None
+        )
+        GANTrainer(QueueInput(get_data()),
+                   Model()).train_with_config(cfg)
@@ -86,16 +86,25 @@ class ModelDescBase(object):
         :returns: a list of InputDesc
         """
 
-    def build_graph(self, inputs):
+    def build_graph(self, *args):
         """
         Build the whole symbolic graph.
 
         Args:
-            inputs (list[tf.Tensor]): a list of tensors,
+            args (list[tf.Tensor]): a list of tensors,
                 that match the list of :class:`InputDesc` defined by ``_get_inputs``.
         """
-        if isinstance(inputs, InputSource):
-            inputs = inputs.get_input_tensors()
+        if len(args) == 1:
+            arg = args[0]
+            if isinstance(arg, InputSource):
+                inputs = arg.get_input_tensors()  # remove in the future?
+            elif isinstance(arg, (list, tuple)):
+                inputs = arg
+            else:
+                inputs = [arg]
+        else:
+            inputs = args
+
         assert len(inputs) == len(self.get_inputs_desc()), \
             "Number of inputs passed to the graph != number of inputs defined " \
             "in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc()))
@@ -148,14 +157,11 @@ class ModelDesc(ModelDescBase):
     def _get_optimizer(self):
         raise NotImplementedError()
 
-    def build_graph_get_cost(self, *inputs):
-        """
-        Build the graph from inputs and return the cost tensor.
-        """
+    def _build_graph_get_cost(self, *inputs):
         self.build_graph(inputs)
         return self.get_cost()
 
-    def build_graph_get_grads(self, *inputs):
+    def _build_graph_get_grads(self, *inputs):
         """
         Build the graph from inputs and return the grads.
         This is useful for most of the :class:`GraphBuilder` which expects such a function.
@@ -78,7 +78,7 @@ def launch_train_with_config(config, trainer):
 
     trainer.setup_graph(
         inputs_desc, input,
-        model.build_graph_get_cost, model.get_optimizer)
+        model._build_graph_get_cost, model.get_optimizer)
     trainer.train(
         config.callbacks, config.monitors,
         config.session_creator, config.session_init,
@@ -64,7 +64,7 @@ class DistributedTrainerReplicated(Trainer):
         self._config.callbacks.extend(cbs)
 
         self.train_op, initial_sync_op, model_sync_op = self._builder.build(
-            lambda: self.model.build_graph_get_grads(
+            lambda: self.model._build_graph_get_grads(
                 *self._input_source.get_input_tensors()),
             self.model.get_optimizer)
@@ -70,7 +70,7 @@ class SyncMultiGPUTrainerParameterServer(Trainer):
 
         self.train_op = SyncMultiGPUParameterServerBuilder(
             self._config.tower, self._ps_device).build(
-                lambda: self.model.build_graph_get_grads(
+                lambda: self.model._build_graph_get_grads(
                     *self._input_source.get_input_tensors()),
                 self.model.get_optimizer)
@@ -104,7 +104,7 @@ class SyncMultiGPUTrainerReplicated(Trainer):
 
         self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(
             self._config.tower).build(
-                lambda: self.model.build_graph_get_grads(
+                lambda: self.model._build_graph_get_grads(
                     *self._input_source.get_input_tensors()),
                 self.model.get_optimizer)
@@ -134,7 +134,7 @@ class AsyncMultiGPUTrainer(Trainer):
 
         self.train_op = AsyncMultiGPUBuilder(
             self._config.tower, self._scale_gradient).build(
-                lambda: self.model.build_graph_get_grads(
+                lambda: self.model._build_graph_get_grads(
                     *self._input_source.get_input_tensors()),
                 self.model.get_optimizer)
@@ -43,7 +43,7 @@ class SimpleTrainer(Trainer):
 
         cbs = self._input_source.setup(self.model.get_inputs_desc())
         with TowerContext('', is_training=True):
-            grads = self.model.build_graph_get_grads(
+            grads = self.model._build_graph_get_grads(
                 *self._input_source.get_input_tensors())
             opt = self.model.get_optimizer()
             self.train_op = opt.apply_gradients(grads, name='min_op')