Commit cef8ae29 authored by Yuxin Wu

Fix main script in all GAN examples

parent 5bd3c395
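
The fix follows one pattern across all of the GAN example scripts: the per-example get_config() helper goes away, TrainConfig no longer carries the model or the dataflow, and each trainer is constructed from an InputSource plus a ModelDesc, then launched with train_with_config(). A condensed before/after sketch of that pattern, assembled from the hunks below (get_data, Model, and the schedule values stand in for whatever each example defines):

    # Before: model and dataflow were packed into TrainConfig.
    config = TrainConfig(
        model=Model(),
        dataflow=get_data(),
        callbacks=[ModelSaver()],
        steps_per_epoch=500,
        max_epoch=100,
    )
    GANTrainer(config).train()

    # After: the trainer owns the input pipeline and the model;
    # TrainConfig keeps only the callbacks and the run schedule.
    config = TrainConfig(
        callbacks=[ModelSaver()],
        steps_per_epoch=500,
        max_epoch=100,
    )
    GANTrainer(QueueInput(get_data()), Model()).train_with_config(config)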
@@ -105,18 +105,6 @@ def get_data():
     return BatchData(ds, BATCH)
 
-def get_config():
-    logger.auto_set_dir()
-    dataset = get_data()
-    return TrainConfig(
-        dataflow=dataset,
-        callbacks=[ModelSaver()],
-        model=Model(),
-        steps_per_epoch=500,
-        max_epoch=100,
-    )
-
 def sample(model_path):
     pred = PredictConfig(
         session_init=get_model_loader(model_path),
@@ -145,7 +133,12 @@ if __name__ == '__main__':
     if args.sample:
         sample(args.load)
     else:
-        config = get_config()
+        logger.auto_set_dir()
+        config = TrainConfig(
+            callbacks=[ModelSaver()],
+            steps_per_epoch=500,
+            max_epoch=100,
+        )
         if args.load:
             config.session_init = SaverRestore(args.load)
-        GANTrainer(config).train()
+        GANTrainer(QueueInput(get_data()), Model()).train_with_config(config)
@@ -219,8 +219,6 @@ if __name__ == '__main__':
     data = PrintData(data)
 
     config = TrainConfig(
-        model=Model(),
-        dataflow=data,
         callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter(
@@ -229,7 +227,8 @@ if __name__ == '__main__':
             PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3),
         ],
         max_epoch=195,
+        steps_per_epoch=data.size(),
         session_init=SaverRestore(args.load) if args.load else None
     )
-    GANTrainer(config).train()
+    GANTrainer(QueueInput(data), Model()).train_with_config(config)
@@ -218,8 +218,6 @@ if __name__ == '__main__':
     data = get_celebA_data(args.data, args.style_A, args.style_B)
 
     config = TrainConfig(
-        model=Model(),
-        dataflow=data,
         callbacks=[ModelSaver()],
         steps_per_epoch=300,
         max_epoch=250,
@@ -227,4 +225,5 @@ if __name__ == '__main__':
     )
     # train 1 D after 2 G
-    SeparateGANTrainer(config, d_period=3).train()
+    SeparateGANTrainer(
+        QueueInput(data), Model(), d_period=3).train_with_config(config)
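
With the config-only TrainConfig, SeparateGANTrainer takes its update schedule directly in the constructor. Going by the comment above ("train 1 D after 2 G"), d_period=3 runs one discriminator step every third iteration with generator steps in between, and g_period (used in the WGAN hunk further down) throttles the generator the same way. A usage sketch under that reading:

    # One D step every 3 iterations ("train 1 D after 2 G"):
    SeparateGANTrainer(QueueInput(data), Model(), d_period=3)
    # One G step every 6 iterations, D steps in between:
    SeparateGANTrainer(QueueInput(data), Model(), g_period=6)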
@@ -70,9 +70,10 @@ class GANTrainer(TowerTrainer):
         assert isinstance(model, GANModelDesc), model
         cbs = input.setup(model.get_inputs_desc())
 
-        tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
+        tower_func = TowerFuncWrapper(
+            model.build_graph, model.get_inputs_desc())
         with TowerContext('', is_training=True):
-            tower_func(input)
+            tower_func(*input.get_input_tensors())
         opt = model.get_optimizer()
 
         # by default, run one d_min after one g_min
@@ -103,7 +104,7 @@ class SeparateGANTrainer(TowerTrainer):
         cbs = input.setup(model.get_inputs_desc())
         tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
         with TowerContext('', is_training=True):
-            tower_func(input)
+            tower_func(*input.get_input_tensors())
         opt = model.get_optimizer()
         with tf.name_scope('optimize'):
...
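
Both trainer classes now feed the tower function the unpacked tensors produced by the InputSource, rather than the InputSource object itself; this is what the build_graph(*args) signature change further down caters to. The shared setup, condensed from the hunk above (a sketch; input is an already-set-up InputSource and model a GANModelDesc):

    cbs = input.setup(model.get_inputs_desc())   # input-pipeline callbacks
    tower_func = TowerFuncWrapper(
        model.build_graph, model.get_inputs_desc())
    with TowerContext('', is_training=True):
        tower_func(*input.get_input_tensors())   # pass tensors, not the source
    opt = model.get_optimizer()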
@@ -169,21 +169,6 @@ def get_data():
     return ds
 
-def get_config():
-    logger.auto_set_dir()
-    dataset = get_data()
-    return TrainConfig(
-        dataflow=dataset,
-        callbacks=[
-            PeriodicTrigger(ModelSaver(), every_k_epochs=3),
-            ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
-        ],
-        model=Model(),
-        steps_per_epoch=dataset.size(),
-        max_epoch=300,
-    )
-
 def sample(datadir, model_path):
     pred = PredictConfig(
         session_init=get_model_loader(model_path),
@@ -219,9 +204,21 @@ if __name__ == '__main__':
     BATCH = args.batch
 
     if args.sample:
+        assert args.load
         sample(args.data, args.load)
     else:
-        config = get_config()
+        logger.auto_set_dir()
+        data = QueueInput(get_data())
+        config = TrainConfig(
+            callbacks=[
+                PeriodicTrigger(ModelSaver(), every_k_epochs=3),
+                ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
+            ],
+            steps_per_epoch=data.size(),
+            max_epoch=300,
+        )
         if args.load:
             config.session_init = SaverRestore(args.load)
-        GANTrainer(config).train()
+        GANTrainer(data, Model()).train_with_config(config)
@@ -96,11 +96,11 @@ if __name__ == '__main__':
     assert args.data
     logger.auto_set_dir()
     config = TrainConfig(
-        model=Model(),
-        dataflow=DCGAN.get_data(args.data),
         callbacks=[ModelSaver()],
         steps_per_epoch=300,
         max_epoch=200,
         session_init=SaverRestore(args.load) if args.load else None
     )
-    SeparateGANTrainer(config, g_period=6).train()
+    SeparateGANTrainer(
+        QueueInput(DCGAN.get_data(args.data)),
+        Model(), g_period=6).train_with_config(config)
@@ -190,17 +190,6 @@ def get_data():
     return ds
 
-def get_config():
-    logger.auto_set_dir('d')
-    return TrainConfig(
-        dataflow=get_data(),
-        callbacks=[ModelSaver(keep_freq=0.1)],
-        model=Model(),
-        steps_per_epoch=500,
-        max_epoch=100,
-    )
-
 def sample(model_path):
     pred = OfflinePredictor(PredictConfig(
         session_init=get_model_loader(model_path),
@@ -255,7 +244,12 @@ if __name__ == '__main__':
         BATCH = 100
         sample(args.load)
     else:
-        config = get_config()
-        if args.load:
-            config.session_init = SaverRestore(args.load)
-        GANTrainer(config).train()
+        logger.auto_set_dir()
+        cfg = TrainConfig(
+            callbacks=[ModelSaver(keep_freq=0.1)],
+            steps_per_epoch=500,
+            max_epoch=100,
+            session_init=SaverRestore(args.load) if args.load else None
+        )
+        GANTrainer(QueueInput(get_data()),
+                   Model()).train_with_config(cfg)
@@ -86,16 +86,25 @@ class ModelDescBase(object):
         :returns: a list of InputDesc
         """
 
-    def build_graph(self, inputs):
+    def build_graph(self, *args):
         """
         Build the whole symbolic graph.
 
         Args:
-            inputs (list[tf.Tensor]): a list of tensors,
+            args (list[tf.Tensor]): a list of tensors,
                 that match the list of :class:`InputDesc` defined by ``_get_inputs``.
         """
-        if isinstance(inputs, InputSource):
-            inputs = inputs.get_input_tensors()
+        if len(args) == 1:
+            arg = args[0]
+            if isinstance(arg, InputSource):
+                inputs = arg.get_input_tensors()  # remove in the future?
+            elif isinstance(arg, (list, tuple)):
+                inputs = arg
+            else:
+                inputs = [arg]
+        else:
+            inputs = args
         assert len(inputs) == len(self.get_inputs_desc()), \
             "Number of inputs passed to the graph != number of inputs defined " \
             "in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc()))
@@ -148,14 +157,11 @@ class ModelDesc(ModelDescBase):
     def _get_optimizer(self):
         raise NotImplementedError()
 
-    def build_graph_get_cost(self, *inputs):
-        """
-        Build the graph from inputs and return the cost tensor.
-        """
+    def _build_graph_get_cost(self, *inputs):
         self.build_graph(inputs)
         return self.get_cost()
 
-    def build_graph_get_grads(self, *inputs):
+    def _build_graph_get_grads(self, *inputs):
         """
         Build the graph from inputs and return the grads.
         This is useful for most of the :class:`GraphBuilder` which expects such a function.
...
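
With the new build_graph(*args) signature above, the input tensors are passed as separate positional arguments; for backward compatibility a single list or tuple of tensors is still accepted, as is an InputSource (flagged in the code as a candidate for removal). Under that dispatch, all of the following resolve to the same inputs list (a sketch; img and label stand for whatever tensors the model's InputDesc list declares, queue_input for any InputSource):

    model.build_graph(img, label)      # new style: unpacked tensors
    model.build_graph([img, label])    # a single list/tuple still works
    model.build_graph(queue_input)     # an InputSource; may be removed later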
@@ -78,7 +78,7 @@ def launch_train_with_config(config, trainer):
     trainer.setup_graph(
         inputs_desc, input,
-        model.build_graph_get_cost, model.get_optimizer)
+        model._build_graph_get_cost, model.get_optimizer)
     trainer.train(
         config.callbacks, config.monitors,
         config.session_creator, config.session_init,
...
@@ -64,7 +64,7 @@ class DistributedTrainerReplicated(Trainer):
         self._config.callbacks.extend(cbs)
 
         self.train_op, initial_sync_op, model_sync_op = self._builder.build(
-            lambda: self.model.build_graph_get_grads(
+            lambda: self.model._build_graph_get_grads(
                 *self._input_source.get_input_tensors()),
             self.model.get_optimizer)
...
@@ -70,7 +70,7 @@ class SyncMultiGPUTrainerParameterServer(Trainer):
         self.train_op = SyncMultiGPUParameterServerBuilder(
             self._config.tower, self._ps_device).build(
-                lambda: self.model.build_graph_get_grads(
+                lambda: self.model._build_graph_get_grads(
                     *self._input_source.get_input_tensors()),
                 self.model.get_optimizer)
@@ -104,7 +104,7 @@ class SyncMultiGPUTrainerReplicated(Trainer):
         self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(
             self._config.tower).build(
-                lambda: self.model.build_graph_get_grads(
+                lambda: self.model._build_graph_get_grads(
                     *self._input_source.get_input_tensors()),
                 self.model.get_optimizer)
@@ -134,7 +134,7 @@ class AsyncMultiGPUTrainer(Trainer):
         self.train_op = AsyncMultiGPUBuilder(
             self._config.tower, self._scale_gradient).build(
-                lambda: self.model.build_graph_get_grads(
+                lambda: self.model._build_graph_get_grads(
                     *self._input_source.get_input_tensors()),
                 self.model.get_optimizer)
...
@@ -43,7 +43,7 @@ class SimpleTrainer(Trainer):
         cbs = self._input_source.setup(self.model.get_inputs_desc())
         with TowerContext('', is_training=True):
-            grads = self.model.build_graph_get_grads(
+            grads = self.model._build_graph_get_grads(
                 *self._input_source.get_input_tensors())
             opt = self.model.get_optimizer()
             self.train_op = opt.apply_gradients(grads, name='min_op')
...
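
The leading underscore makes the cost/grads helpers private to the framework: as the hunks above show, every trainer hands its graph builder a zero-argument closure over the helper and the input tensors, and the builder invokes it to construct each tower. The shared shape of those call sites, condensed (a sketch):

    # Each trainer wraps the private helper in a closure; the builder
    # calls it to build the tower graph and collect the gradients.
    train_op = builder.build(
        lambda: self.model._build_graph_get_grads(
            *self._input_source.get_input_tensors()),
        self.model.get_optimizer)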