Commit 2870347c authored by Yuxin Wu

Register a list of callbacks at a time

parent 712ea325
@@ -68,15 +68,24 @@ class GANTrainer(TowerTrainer):
         super(GANTrainer, self).__init__()
         assert isinstance(model, GANModelDesc), model
         inputs_desc = model.get_inputs_desc()
+        # Setup input
         cbs = input.setup(inputs_desc)
+        self.register_callback(cbs)
 
-        # we need to set towerfunc because it's a TowerTrainer,
-        # and only TowerTrainer supports automatic graph creation for inference during training.
+        """
+        We need to set tower_func because it's a TowerTrainer,
+        and only TowerTrainer supports automatic graph creation for inference during training.
+        If we don't care about inference during training, using tower_func is
+        not needed. Just calling model.build_graph directly is OK.
+        """
+        # Build the graph
         self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc)
         with TowerContext('', is_training=True):
             self.tower_func(*input.get_input_tensors())
         opt = model.get_optimizer()
 
+        # Define the training iteration
+        # by default, run one d_min after one g_min
         with tf.name_scope('optimize'):
             g_min = opt.minimize(model.g_loss, var_list=model.g_vars, name='g_op')
@@ -84,9 +93,6 @@ class GANTrainer(TowerTrainer):
                 d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
         self.train_op = d_min
-        for cb in cbs:
-            self.register_callback(cb)
 class SeparateGANTrainer(TowerTrainer):
     """ A GAN trainer which runs two optimization ops with a certain ratio."""
@@ -101,7 +107,11 @@ class SeparateGANTrainer(TowerTrainer):
         self._g_period = int(g_period)
         assert min(d_period, g_period) == 1
+
+        # Setup input
         cbs = input.setup(model.get_inputs_desc())
+        self.register_callback(cbs)
 
+        # Build the graph
         self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
         with TowerContext('', is_training=True):
             self.tower_func(*input.get_input_tensors())
@@ -113,10 +123,8 @@ class SeparateGANTrainer(TowerTrainer):
         self.g_min = opt.minimize(
             model.g_loss, var_list=model.g_vars, name='g_min')
-        for cb in cbs:
-            self.register_callback(cb)
 
     def run_step(self):
+        # Define the training iteration
         if self.global_step % (self._d_period) == 0:
             self.hooked_sess.run(self.d_min)
         if self.global_step % (self._g_period) == 0:
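The `d_period`/`g_period` scheme above runs each op on the steps where the global step is divisible by its period, and the earlier assert guarantees at least one of the two runs on every step. A small pure-Python sketch of the resulting schedule (the function and op names here are illustrative only):

```python
def schedule(num_steps, d_period=1, g_period=5):
    """Return which ops run at each global step under SeparateGANTrainer's rule."""
    assert min(d_period, g_period) == 1
    ops = []
    for step in range(num_steps):
        ran = []
        if step % d_period == 0:
            ran.append('d_min')
        if step % g_period == 0:
            ran.append('g_min')
        ops.append(ran)
    return ops

# With g_period=5 the discriminator trains every step while the generator
# trains only on steps 0, 5, 10, ... (a common ratio for WGAN-style training).
print(schedule(6))
# [['d_min', 'g_min'], ['d_min'], ['d_min'], ['d_min'], ['d_min'], ['d_min', 'g_min']]
```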
@@ -132,11 +140,12 @@ class MultiGPUGANTrainer(TowerTrainer):
         assert nr_gpu > 1
         raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)]
 
-        # setup input
+        # Setup input
         input = StagingInput(input, list(range(nr_gpu)))
         cbs = input.setup(model.get_inputs_desc())
+        self.register_callback(cbs)
 
-        # build the graph
+        # Build the graph with multi-gpu replication
         def get_cost(*inputs):
             model.build_graph(*inputs)
             return [model.d_loss, model.g_loss]
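`get_cost` rebuilds the model inside each tower and hands back the two scalars, so the replication helper can collect one `(d_loss, g_loss)` pair per GPU. A simplified stand-in for that collect-per-tower pattern; `build_on_towers` below is a hypothetical sketch, not tensorpack's real `DataParallelBuilder`:

```python
def build_on_towers(tower_ids, tower_fn, devices):
    # Call the tower function once per device and collect what it returns.
    # The real builder would also enter tf.device(dev) and a variable scope
    # so that variables are shared across towers.
    results = []
    for tid, dev in zip(tower_ids, devices):
        results.append(tower_fn())
    return results

def get_cost():
    # Stand-in for model.build_graph(*inputs) followed by reading the losses.
    return (0.5, 1.5)   # (d_loss, g_loss) of this tower

cost_list = build_on_towers([0, 1], get_cost, ['/gpu:0', '/gpu:1'])
print(cost_list)  # [(0.5, 1.5), (0.5, 1.5)], one pair per tower
```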
@@ -146,7 +155,7 @@ class MultiGPUGANTrainer(TowerTrainer):
             list(range(nr_gpu)),
             lambda: self.tower_func(*input.get_input_tensors()),
             devices)
-        # simply average the cost. It might get faster to average the gradients
+        # Simply average the cost here. It might be faster to average the gradients
         with tf.name_scope('optimize'):
             d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
             g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu)
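The comment weighs averaging costs against averaging gradients; the two give the same update because differentiation is linear (the gradient of the mean cost equals the mean of the per-tower gradients), so the choice is purely about graph size and speed. A toy check of that identity, with hypothetical one-variable "towers":

```python
import tensorflow as tf

w = tf.get_variable('w', initializer=2.0)
tower_costs = [tf.square(w), 3.0 * w]          # stand-ins for per-tower losses
mean_cost = tf.add_n(tower_costs) * (1.0 / len(tower_costs))

grad_of_mean = tf.gradients(mean_cost, w)[0]
mean_of_grads = tf.add_n(tf.gradients(tower_costs[0], w) +
                         tf.gradients(tower_costs[1], w)) * 0.5

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Both print 3.5: averaging costs and averaging gradients coincide.
    print(sess.run([grad_of_mean, mean_of_grads]))
```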
@@ -158,9 +167,8 @@ class MultiGPUGANTrainer(TowerTrainer):
             with tf.control_dependencies([g_min]):
                 d_min = opt.minimize(d_loss, var_list=model.d_vars,
                                      colocate_gradients_with_ops=True, name='d_op')
+        # Define the training iteration
         self.train_op = d_min
-        for cb in cbs:
-            self.register_callback(cb)
 
 
 class RandomZData(DataFlow):
@@ -135,9 +135,16 @@ class Trainer(object):
     def _register_callback(self, cb):
         """
-        Register a callback to the trainer.
+        Register callbacks to the trainer.
         It can only be called before :meth:`Trainer.train()`.
+
+        Args:
+            cb (Callback or [Callback]): a callback or a list of callbacks
         """
+        if isinstance(cb, (list, tuple)):
+            for x in cb:
+                self._register_callback(x)
+            return
         assert isinstance(cb, Callback), cb
         assert not isinstance(self._callbacks, Callbacks), \
             "Cannot register more callbacks after trainer was setup!"
@@ -145,9 +145,7 @@ class SingleCostTrainer(TowerTrainer):
         # TODO setup may want to register monitor as well??
         input_callbacks = self._setup_input(inputs_desc, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
-        internal_callbacks = input_callbacks + train_callbacks
-        for cb in internal_callbacks:
-            self.register_callback(cb)
+        self.register_callback(input_callbacks + train_callbacks)
 
     @abstractmethod
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):