Commit 8419ee3f authored by Yuxin Wu

Let InputSource.setup() return the callbacks, to simplify trainer implementations

parent 398cb933
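
In short: `InputSource.setup()` previously only built the input machinery, and every trainer had to make a second call to fetch the input's callbacks. After this commit, `setup()` returns them directly. A minimal sketch of the calling pattern before and after (an illustration, not code from the commit; `input`, `model` and `config` are assumed to exist as in the trainers below):

```python
# Pattern before this commit -- every trainer needed two calls:
#   input.setup(model.get_inputs_desc())
#   cbs = input.get_callbacks()

# Pattern after this commit -- setup() returns the callbacks directly:
cbs = input.setup(model.get_inputs_desc())
config.callbacks.extend(cbs)
```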
@@ -66,21 +66,24 @@ class GANModelDesc(ModelDescBase):
 class GANTrainer(Trainer):
     def __init__(self, config):
-        self._input_source = QueueInput(config.dataflow)
-        super(GANTrainer, self).__init__(config)
-
-    def _setup(self):
-        self._setup_input_source(self._input_source)
+        input = QueueInput(config.dataflow)
+        model = config.model
+        cbs = input.setup(model.get_inputs_desc())
+        config.callbacks.extend(cbs)
+
         with TowerContext('', is_training=True):
-            self.model.build_graph(self._input_source)
-        opt = self.model.get_optimizer()
+            model.build_graph(input)
+        opt = model.get_optimizer()
 
         # by default, run one d_min after one g_min
-        g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
+        g_min = opt.minimize(model.g_loss, var_list=model.g_vars, name='g_op')
         with tf.control_dependencies([g_min]):
-            d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
+            d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
         self.train_op = d_min
 
+        super(GANTrainer, self).__init__(config)
+
 
 class SeparateGANTrainer(Trainer):
     """ A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
@@ -90,30 +93,31 @@ class SeparateGANTrainer(Trainer):
             d_period(int): period of each d_opt run
             g_period(int): period of each g_opt run
         """
-        self._input_source = QueueInput(config.dataflow)
         self._d_period = int(d_period)
         self._g_period = int(g_period)
         assert min(d_period, g_period) == 1
-        super(SeparateGANTrainer, self).__init__(config)
 
-    def _setup(self):
-        self._setup_input_source(self._input_source)
+        input = QueueInput(config.dataflow)
+        model = config.model
+        cbs = input.setup(model.get_inputs_desc())
+        config.callbacks.extend(cbs)
+
         with TowerContext('', is_training=True):
-            self.model.build_graph(self._input_source)
-        opt = self.model.get_optimizer()
+            model.build_graph(input)
+        opt = model.get_optimizer()
         self.d_min = opt.minimize(
-            self.model.d_loss, var_list=self.model.d_vars, name='d_min')
+            model.d_loss, var_list=model.d_vars, name='d_min')
         self.g_min = opt.minimize(
-            self.model.g_loss, var_list=self.model.g_vars, name='g_min')
-        self._cnt = 1
+            model.g_loss, var_list=model.g_vars, name='g_min')
+
+        super(SeparateGANTrainer, self).__init__(config)
 
     def run_step(self):
-        if self._cnt % (self._d_period) == 0:
+        if self.global_step % (self._d_period) == 0:
             self.hooked_sess.run(self.d_min)
-        if self._cnt % (self._g_period) == 0:
+        if self.global_step % (self._g_period) == 0:
             self.hooked_sess.run(self.g_min)
-        self._cnt += 1
 
 
 class MultiGPUGANTrainer(Trainer):
@@ -121,33 +125,35 @@ class MultiGPUGANTrainer(Trainer):
     A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
     """
     def __init__(self, config):
-        self._nr_gpu = config.nr_tower
-        assert self._nr_gpu > 1
-        self._raw_devices = ['/gpu:{}'.format(k) for k in config.tower]
-        self._input_source = StagingInputWrapper(QueueInput(config.dataflow), self._raw_devices)
-        super(MultiGPUGANTrainer, self).__init__(config)
+        nr_gpu = config.nr_tower
+        assert nr_gpu > 1
+        raw_devices = ['/gpu:{}'.format(k) for k in config.tower]
 
-    def _setup(self):
-        self._setup_input_source(self._input_source)
-        devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]
+        # setup input
+        input = StagingInputWrapper(QueueInput(config.dataflow), raw_devices)
+        model = config.model
+        cbs = input.setup(model.get_inputs_desc())
+        config.callbacks.extend(cbs)
 
         def get_cost():
-            self.model.build_graph(self._input_source)
-            return [self.model.d_loss, self.model.g_loss]
+            model.build_graph(input)
+            return [model.d_loss, model.g_loss]
+
+        devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         cost_list = MultiGPUTrainerBase.build_on_multi_tower(
-            self.config.tower, get_cost, devices)
-        # simply average the cost. might be faster to average the gradients
-        d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / self._nr_gpu)
-        g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / self._nr_gpu)
+            config.tower, get_cost, devices)
+        # simply average the cost. It might be faster to average the gradients
+        d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
+        g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu)
 
-        opt = self.model.get_optimizer()
+        opt = model.get_optimizer()
         # run one d_min after one g_min
-        g_min = opt.minimize(g_loss, var_list=self.model.g_vars,
+        g_min = opt.minimize(g_loss, var_list=model.g_vars,
                              colocate_gradients_with_ops=True, name='g_op')
         with tf.control_dependencies([g_min]):
-            d_min = opt.minimize(d_loss, var_list=self.model.d_vars,
+            d_min = opt.minimize(d_loss, var_list=model.d_vars,
                                  colocate_gradients_with_ops=True, name='d_op')
         self.train_op = d_min
+
+        super(MultiGPUGANTrainer, self).__init__(config)
 
 
 class RandomZData(DataFlow):
...
@@ -93,13 +93,12 @@ class InferenceRunnerBase(Callback):
         tower_id = self.trainer.config.predict_tower[0]
         device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
 
-        self._input_source.setup(self.trainer.model.get_inputs_desc())
+        cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
         with tf.variable_scope(tf.get_variable_scope(), reuse=True):
             self._tower_handle = self.trainer.predictor_factory.build(
                 self._tower_name, device, self._input_source)
 
         self._hooks = [self._build_hook(inf) for inf in self.infs]
-        cbs = self._input_source.get_callbacks()
         self._hooks.extend([CallbackToHook(cb) for cb in cbs])
 
     def _before_train(self):
@@ -173,7 +172,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         self._gpus = gpus
 
     def _setup_graph(self):
-        self._input_source.setup(self.trainer.model.get_inputs_desc())
+        cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
         self._handles = []
         with tf.variable_scope(tf.get_variable_scope(), reuse=True):
             for idx, t in enumerate(self._gpus):
@@ -186,7 +185,6 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         # setup feeds and hooks
         self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
         self._hooks = [self._build_hook(inf) for inf in self.infs]
-        cbs = self._input_source.get_callbacks()
         self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs])
 
     class InferencerToHookDataParallel(InferencerToHook):
...
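
For context on the `CallbackToHook(cb)` calls above: the callbacks returned by `setup()` are wrapped as session hooks so they keep running inside the inference loop. Below is a toy adapter in the spirit of tensorpack's `CallbackToHook`, assuming TF1's `tf.train.SessionRunHook` interface (a sketch of the shape, not the library's actual implementation):

```python
import tensorflow as tf

class CallbackToHookSketch(tf.train.SessionRunHook):
    """Toy adapter: forwards a MonitoredSession's per-run hook points
    to a tensorpack-style callback."""
    def __init__(self, cb):
        self._cb = cb

    def before_run(self, run_context):
        return self._cb.before_run(run_context)    # may request extra fetches

    def after_run(self, run_context, run_values):
        self._cb.after_run(run_context, run_values)
```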
@@ -5,6 +5,7 @@
 from abc import ABCMeta, abstractmethod
 import six
 
+from ..utils.argtools import memoized
 from ._utils import get_sublist_by_names, get_tensors_inputs
 
 __all__ = ['InputSource', 'remap_input_source']
@@ -31,16 +32,25 @@ class InputSource(object):
         """
         Args:
             inputs_desc (list[InputDesc]): list of input desc
+
+        Returns:
+            list[Callback]: extra callbacks needed by this InputSource.
         """
         self._setup(inputs_desc)
+        return self.get_callbacks()
 
     def _setup(self, inputs_desc):
         pass
 
+    @memoized
     def get_callbacks(self):
         """
+        An InputSource might need some extra maintenance during training,
+        which is also done through the Callback interface.
+        This method returns the callbacks; the return value is memoized.
+
         Returns:
-            list[Callback]: extra callbacks required by this InputSource.
+            list[Callback]: extra callbacks needed by this InputSource.
         """
         return self._get_callbacks()
...
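
Why the new `@memoized` on `get_callbacks()` matters: `setup()` now calls `get_callbacks()` itself, yet existing call sites may still call it directly, and memoization guarantees both paths receive the single list built by `_get_callbacks()`. A self-contained sketch of the contract, using `functools.lru_cache` as a stand-in for tensorpack's `memoized` (class and method bodies are hypothetical):

```python
import functools

class InputSourceSketch:
    """Stand-in mirroring the new InputSource contract."""
    def setup(self, inputs_desc):
        self._setup(inputs_desc)
        return self.get_callbacks()        # setup() hands back the callbacks

    def _setup(self, inputs_desc):
        pass                               # subclasses build their input ops here

    @functools.lru_cache(maxsize=None)     # plays the role of @memoized
    def get_callbacks(self):
        return self._get_callbacks()

    def _get_callbacks(self):
        return []                          # subclasses return their callbacks

src = InputSourceSketch()
cbs = src.setup(inputs_desc=None)
assert cbs is src.get_callbacks()          # later calls see the exact same list
```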
@@ -53,6 +53,12 @@ class PredictorFactory(object):
         self._names_built = {}
 
     def build(self, tower_name, device, input=None):
+        """
+        Args:
+            tower_name (str):
+            device (str):
+            input (InputSource): must already be set up. If None, will use the InputDesc from the model.
+        """
         logger.info("Building predictor tower '{}' on device {} ...".format(tower_name, device))
         assert tower_name not in self._names_built
...
@@ -117,8 +117,7 @@ class Trainer(object):
         """
         Setup InputSource on this trainer.
         """
-        input_source.setup(self.model.get_inputs_desc())
-        cbs = input_source.get_callbacks()
+        cbs = input_source.setup(self.model.get_inputs_desc())
         self.config.callbacks.extend(cbs)
 
     def setup(self):
...
@@ -202,7 +202,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
             [Callback]: the callbacks to be added
         """
-        input.setup(model.get_inputs_desc())
+        callbacks = input.setup(model.get_inputs_desc())
         raw_devices = ['/gpu:{}'.format(k) for k in tower]
         if ps_device == 'gpu':
@@ -226,7 +226,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
         # grads = grad_list[0]
         train_op = model.get_optimizer().apply_gradients(grads, name='train_op')
-        return train_op, input.get_callbacks()
+        return train_op, callbacks
 
     def _setup(self):
         self.train_op, cbs = SyncMultiGPUTrainerParameterServer.setup_graph(
@@ -294,7 +294,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
             [Callback]: the callbacks to be added
         """
-        input.setup(model.get_inputs_desc())
+        callbacks = input.setup(model.get_inputs_desc())
         raw_devices = ['/gpu:{}'.format(k) for k in tower]
@@ -317,7 +317,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
         cb = RunOp(
             SyncMultiGPUTrainerReplicated.get_post_init_ops,
             run_before=True, run_as_trigger=True, verbose=True)
-        return train_op, input.get_callbacks() + [cb]
+        return train_op, callbacks + [cb]
 
     def _setup(self):
         self.train_op, cbs = SyncMultiGPUTrainerReplicated.setup_graph(
@@ -379,7 +379,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
             [Callback]: the callbacks to be added
         """
-        input.setup(model.get_inputs_desc())
+        callbacks = input.setup(model.get_inputs_desc())
         raw_devices = ['/gpu:{}'.format(k) for k in tower]
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
@@ -404,7 +404,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
             # will call apply_gradients (therefore gradproc) multiple times
             train_ops.append(opt.apply_gradients(
                 grad_and_vars, name='apply_grad_{}'.format(i)))
-        return tf.group(*train_ops, name='train_op'), input.get_callbacks()
+        return tf.group(*train_ops, name='train_op'), callbacks
 
     def _setup(self):
         self.train_op, cbs = AsyncMultiGPUTrainer.setup_graph(
...
@@ -53,8 +53,7 @@ class SimpleTrainer(Trainer):
             [Callback]: the callbacks to be added
         """
-        input.setup(model.get_inputs_desc())
-        cbs = input.get_callbacks()
+        cbs = input.setup(model.get_inputs_desc())
         with TowerContext('', is_training=True):
             model.build_graph(input)
             _, grads = model.get_cost_and_grad()
...