Commit 2870347c authored by Yuxin Wu

Register a list of callbacks at a time

parent 712ea325
@@ -68,15 +68,24 @@ class GANTrainer(TowerTrainer):
         super(GANTrainer, self).__init__()
         assert isinstance(model, GANModelDesc), model
         inputs_desc = model.get_inputs_desc()
+        # Setup input
         cbs = input.setup(inputs_desc)
+        self.register_callback(cbs)
-        # we need to set towerfunc because it's a TowerTrainer,
-        # and only TowerTrainer supports automatic graph creation for inference during training.
+        """
+        We need to set tower_func because it's a TowerTrainer,
+        and only TowerTrainer supports automatic graph creation for inference during training.
+        If we don't care about inference during training, using tower_func is
+        not needed. Just calling model.build_graph directly is OK.
+        """
+        # Build the graph
         self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc)
         with TowerContext('', is_training=True):
             self.tower_func(*input.get_input_tensors())
         opt = model.get_optimizer()
+        # Define the training iteration
         # by default, run one d_min after one g_min
         with tf.name_scope('optimize'):
             g_min = opt.minimize(model.g_loss, var_list=model.g_vars, name='g_op')
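This hunk shows the heart of the commit: `input.setup()` returns a list of callbacks, and `register_callback` now accepts that list directly, so the per-callback loop removed further down is no longer needed. Side by side, the registration idiom (lines taken from this diff):

# Before: register the callbacks one at a time
for cb in cbs:
    self.register_callback(cb)

# After: pass the list returned by input.setup() in a single call
self.register_callback(cbs)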
@@ -84,9 +93,6 @@ class GANTrainer(TowerTrainer):
             d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
         self.train_op = d_min
-        for cb in cbs:
-            self.register_callback(cb)


 class SeparateGANTrainer(TowerTrainer):
     """ A GAN trainer which runs two optimization ops with a certain ratio."""
@@ -101,7 +107,11 @@ class SeparateGANTrainer(TowerTrainer):
         self._g_period = int(g_period)
         assert min(d_period, g_period) == 1

+        # Setup input
         cbs = input.setup(model.get_inputs_desc())
+        self.register_callback(cbs)
+
+        # Build the graph
         self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
         with TowerContext('', is_training=True):
             self.tower_func(*input.get_input_tensors())
@@ -113,10 +123,8 @@ class SeparateGANTrainer(TowerTrainer):
             self.g_min = opt.minimize(
                 model.g_loss, var_list=model.g_vars, name='g_min')
-        for cb in cbs:
-            self.register_callback(cb)

     def run_step(self):
+        # Define the training iteration
         if self.global_step % (self._d_period) == 0:
             self.hooked_sess.run(self.d_min)
         if self.global_step % (self._g_period) == 0:
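`run_step` runs `d_min` on every step whose index is a multiple of `_d_period` and `g_min` on multiples of `_g_period`; the `assert min(d_period, g_period) == 1` above guarantees at least one of the two ops runs every step. A standalone sketch of the schedule, with hypothetical periods d_period=1, g_period=3 (example values, not defaults from the source):

# Pure-Python sketch of SeparateGANTrainer's update schedule (no TF needed).
d_period, g_period = 1, 3
for global_step in range(7):
    ops = []
    if global_step % d_period == 0:
        ops.append('d_min')
    if global_step % g_period == 0:
        ops.append('g_min')
    print(global_step, ops)
# Prints: 0 ['d_min', 'g_min'], 1 ['d_min'], 2 ['d_min'],
#         3 ['d_min', 'g_min'], 4 ['d_min'], ...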
@@ -132,11 +140,12 @@ class MultiGPUGANTrainer(TowerTrainer):
         assert nr_gpu > 1
         raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)]

-        # setup input
+        # Setup input
         input = StagingInput(input, list(range(nr_gpu)))
         cbs = input.setup(model.get_inputs_desc())
+        self.register_callback(cbs)

-        # build the graph
+        # Build the graph with multi-gpu replication
         def get_cost(*inputs):
             model.build_graph(*inputs)
             return [model.d_loss, model.g_loss]
@@ -146,7 +155,7 @@ class MultiGPUGANTrainer(TowerTrainer):
             list(range(nr_gpu)),
             lambda: self.tower_func(*input.get_input_tensors()),
             devices)
-        # simply average the cost. It might get faster to average the gradients
+        # Simply average the cost here. It might be faster to average the gradients
         with tf.name_scope('optimize'):
             d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
             g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu)
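A note on the comment above: because differentiation is linear, averaging the per-tower costs and then taking gradients gives the same values as averaging per-tower gradients; the difference is only in graph size and speed. A tiny TF1-style sketch of that equivalence (the variable `w` and the two costs are hypothetical stand-ins, not from this diff):

import tensorflow as tf

w = tf.Variable(3.0)
c0, c1 = w * w, 2.0 * w                    # stand-ins for two towers' losses
mean_cost = tf.add_n([c0, c1]) * 0.5       # average the costs, as in the diff
g_from_cost = tf.gradients(mean_cost, [w])[0]
g_from_grads = tf.add_n([tf.gradients(c, [w])[0] for c in (c0, c1)]) * 0.5

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([g_from_cost, g_from_grads]))   # both 4.0 at w=3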
@@ -158,9 +167,8 @@ class MultiGPUGANTrainer(TowerTrainer):
         with tf.control_dependencies([g_min]):
             d_min = opt.minimize(d_loss, var_list=model.d_vars,
                                  colocate_gradients_with_ops=True, name='d_op')
+        # Define the training iteration
         self.train_op = d_min
-        for cb in cbs:
-            self.register_callback(cb)


 class RandomZData(DataFlow):
...
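In the hunk above, `tf.control_dependencies([g_min])` makes `d_min` depend on `g_min`, so a single `sess.run(train_op)` performs one generator update followed by one discriminator update. A minimal runnable illustration of that ordering guarantee (the variable and ops here are hypothetical, not from the source):

import tensorflow as tf

x = tf.Variable(0.0)
first = tf.assign_add(x, 1.0)                    # plays the role of g_min
with tf.control_dependencies([first]):
    # read_value() is created inside the scope, so it reads after `first`
    second = tf.assign(x, x.read_value() * 2.0)  # plays the role of d_min

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(second))   # 2.0: first ran (0 -> 1), then second (1 -> 2)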
@@ -135,9 +135,16 @@ class Trainer(object):
     def _register_callback(self, cb):
         """
-        Register a callback to the trainer.
+        Register callbacks to the trainer.
         It can only be called before :meth:`Trainer.train()`.
+
+        Args:
+            cb (Callback or [Callback]): a callback or a list of callbacks
         """
+        if isinstance(cb, (list, tuple)):
+            for x in cb:
+                self._register_callback(x)
+            return
         assert isinstance(cb, Callback), cb
         assert not isinstance(self._callbacks, Callbacks), \
             "Cannot register more callbacks after trainer was setup!"
...
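Since `_register_callback` re-invokes itself for each element of a list or tuple, nested containers are flattened too, and single callbacks still pass through the original assertions. A simplified standalone sketch of the same pattern (the classes here are stand-ins, not tensorpack's real ones):

class Callback(object):
    pass

class Trainer(object):
    def __init__(self):
        self._callbacks = []

    def register_callback(self, cb):
        # Accept one callback, or a (possibly nested) list/tuple of callbacks.
        if isinstance(cb, (list, tuple)):
            for x in cb:
                self.register_callback(x)
            return
        assert isinstance(cb, Callback), cb
        self._callbacks.append(cb)

t = Trainer()
t.register_callback([Callback(), (Callback(), Callback())])
assert len(t._callbacks) == 3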
@@ -145,9 +145,7 @@ class SingleCostTrainer(TowerTrainer):
         # TODO setup may want to register monitor as well??
         input_callbacks = self._setup_input(inputs_desc, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
-        internal_callbacks = input_callbacks + train_callbacks
-        for cb in internal_callbacks:
-            self.register_callback(cb)
+        self.register_callback(input_callbacks + train_callbacks)

     @abstractmethod
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
...