Commit 8419ee3f authored by Yuxin Wu

Let InputSource.setup() return the callbacks, to simplify trainer implementations

parent 398cb933
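
With this change, a trainer obtains the extra callbacks directly from setup() instead of calling get_callbacks() separately. A minimal sketch of the new trainer-side pattern, as it appears in the diff below (assuming a tensorpack TrainConfig-like `config` carrying a `dataflow` and a `model`; the import path is assumed for this sketch):

from tensorpack import QueueInput             # import path assumed for this sketch

input = QueueInput(config.dataflow)           # wrap the dataflow in an InputSource
model = config.model
cbs = input.setup(model.get_inputs_desc())    # setup() now returns list[Callback]
config.callbacks.extend(cbs)                  # register them; no separate get_callbacks() call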
@@ -66,21 +66,24 @@ class GANModelDesc(ModelDescBase):
class GANTrainer(Trainer):
def __init__(self, config):
self._input_source = QueueInput(config.dataflow)
super(GANTrainer, self).__init__(config)
input = QueueInput(config.dataflow)
model = config.model
cbs = input.setup(model.get_inputs_desc())
config.callbacks.extend(cbs)
def _setup(self):
self._setup_input_source(self._input_source)
with TowerContext('', is_training=True):
self.model.build_graph(self._input_source)
opt = self.model.get_optimizer()
model.build_graph(input)
opt = model.get_optimizer()
# by default, run one d_min after one g_min
g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
g_min = opt.minimize(model.g_loss, var_list=model.g_vars, name='g_op')
with tf.control_dependencies([g_min]):
d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
self.train_op = d_min
super(GANTrainer, self).__init__(config)
class SeparateGANTrainer(Trainer):
""" A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
@@ -90,30 +93,31 @@ class SeparateGANTrainer(Trainer):
d_period(int): period of each d_opt run
g_period(int): period of each g_opt run
"""
self._input_source = QueueInput(config.dataflow)
self._d_period = int(d_period)
self._g_period = int(g_period)
assert min(d_period, g_period) == 1
super(SeparateGANTrainer, self).__init__(config)
def _setup(self):
self._setup_input_source(self._input_source)
input = QueueInput(config.dataflow)
model = config.model
cbs = input.setup(model.get_inputs_desc())
config.callbacks.extend(cbs)
with TowerContext('', is_training=True):
self.model.build_graph(self._input_source)
model.build_graph(input)
opt = self.model.get_optimizer()
opt = model.get_optimizer()
self.d_min = opt.minimize(
self.model.d_loss, var_list=self.model.d_vars, name='d_min')
model.d_loss, var_list=model.d_vars, name='d_min')
self.g_min = opt.minimize(
self.model.g_loss, var_list=self.model.g_vars, name='g_min')
self._cnt = 1
model.g_loss, var_list=model.g_vars, name='g_min')
super(SeparateGANTrainer, self).__init__(config)
def run_step(self):
if self._cnt % (self._d_period) == 0:
if self.global_step % (self._d_period) == 0:
self.hooked_sess.run(self.d_min)
if self._cnt % (self._g_period) == 0:
if self.global_step % (self._g_period) == 0:
self.hooked_sess.run(self.g_min)
self._cnt += 1
class MultiGPUGANTrainer(Trainer):
@@ -121,33 +125,35 @@ class MultiGPUGANTrainer(Trainer):
A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
"""
def __init__(self, config):
self._nr_gpu = config.nr_tower
assert self._nr_gpu > 1
self._raw_devices = ['/gpu:{}'.format(k) for k in config.tower]
self._input_source = StagingInputWrapper(QueueInput(config.dataflow), self._raw_devices)
super(MultiGPUGANTrainer, self).__init__(config)
nr_gpu = config.nr_tower
assert nr_gpu > 1
raw_devices = ['/gpu:{}'.format(k) for k in config.tower]
def _setup(self):
self._setup_input_source(self._input_source)
devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]
# setup input
input = StagingInputWrapper(QueueInput(config.dataflow), raw_devices)
model = config.model
cbs = input.setup(model.get_inputs_desc())
config.callbacks.extend(cbs)
def get_cost():
self.model.build_graph(self._input_source)
return [self.model.d_loss, self.model.g_loss]
model.build_graph(input)
return [model.d_loss, model.g_loss]
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower, get_cost, devices)
# simply average the cost. might be faster to average the gradients
d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / self._nr_gpu)
g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / self._nr_gpu)
config.tower, get_cost, devices)
# simply average the cost. It might be faster to average the gradients
d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu)
opt = self.model.get_optimizer()
opt = model.get_optimizer()
# run one d_min after one g_min
g_min = opt.minimize(g_loss, var_list=self.model.g_vars,
g_min = opt.minimize(g_loss, var_list=model.g_vars,
colocate_gradients_with_ops=True, name='g_op')
with tf.control_dependencies([g_min]):
d_min = opt.minimize(d_loss, var_list=self.model.d_vars,
d_min = opt.minimize(d_loss, var_list=model.d_vars,
colocate_gradients_with_ops=True, name='d_op')
self.train_op = d_min
super(MultiGPUGANTrainer, self).__init__(config)
class RandomZData(DataFlow):
@@ -93,13 +93,12 @@ class InferenceRunnerBase(Callback):
tower_id = self.trainer.config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
self._input_source.setup(self.trainer.model.get_inputs_desc())
cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(
self._tower_name, device, self._input_source)
self._hooks = [self._build_hook(inf) for inf in self.infs]
cbs = self._input_source.get_callbacks()
self._hooks.extend([CallbackToHook(cb) for cb in cbs])
def _before_train(self):
@@ -173,7 +172,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self._gpus = gpus
def _setup_graph(self):
self._input_source.setup(self.trainer.model.get_inputs_desc())
cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
self._handles = []
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, t in enumerate(self._gpus):
@@ -186,7 +185,6 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
# setup feeds and hooks
self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
self._hooks = [self._build_hook(inf) for inf in self.infs]
cbs = self._input_source.get_callbacks()
self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs])
class InferencerToHookDataParallel(InferencerToHook):
@@ -5,6 +5,7 @@
from abc import ABCMeta, abstractmethod
import six
from ..utils.argtools import memoized
from ._utils import get_sublist_by_names, get_tensors_inputs
__all__ = ['InputSource', 'remap_input_source']
@@ -31,16 +32,25 @@ class InputSource(object):
"""
Args:
inputs_desc (list[InputDesc]): list of input desc
Returns:
list[Callback]: extra callbacks needed by this InputSource.
"""
self._setup(inputs_desc)
return self.get_callbacks()
def _setup(self, inputs_desc):
pass
@memoized
def get_callbacks(self):
"""
An InputSource might need some extra maintenance during training,
which is also done through the Callback interface.
This method returns those callbacks; the return value is memoized.
Returns:
list[Callback]: extra callbacks required by this InputSource.
list[Callback]: extra callbacks needed by this InputSource.
"""
return self._get_callbacks()
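
Under this contract, setup() runs _setup() and then returns the memoized result of get_callbacks(). A hypothetical subclass (illustrative only, not part of this commit) only needs to override the two private hooks:

class DummyInput(InputSource):
    def _setup(self, inputs_desc):
        # e.g. build placeholders or queues from inputs_desc
        self._desc = inputs_desc

    def _get_callbacks(self):
        # e.g. return a callback that keeps a queue filled; nothing needed here
        return []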
@@ -53,6 +53,12 @@ class PredictorFactory(object):
self._names_built = {}
def build(self, tower_name, device, input=None):
"""
Args:
tower_name (str):
device(str):
input (InputSource): must already be set up. If None, will use InputDesc from the model.
"""
logger.info("Building predictor tower '{}' on device {} ...".format(tower_name, device))
assert tower_name not in self._names_built
@@ -117,8 +117,7 @@ class Trainer(object):
"""
Setup InputSource on this trainer.
"""
input_source.setup(self.model.get_inputs_desc())
cbs = input_source.get_callbacks()
cbs = input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs)
def setup(self):
@@ -202,7 +202,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
[Callback]: the callbacks to be added
"""
input.setup(model.get_inputs_desc())
callbacks = input.setup(model.get_inputs_desc())
raw_devices = ['/gpu:{}'.format(k) for k in tower]
if ps_device == 'gpu':
@@ -226,7 +226,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
# grads = grad_list[0]
train_op = model.get_optimizer().apply_gradients(grads, name='train_op')
return train_op, input.get_callbacks()
return train_op, callbacks
def _setup(self):
self.train_op, cbs = SyncMultiGPUTrainerParameterServer.setup_graph(
@@ -294,7 +294,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
[Callback]: the callbacks to be added
"""
input.setup(model.get_inputs_desc())
callbacks = input.setup(model.get_inputs_desc())
raw_devices = ['/gpu:{}'.format(k) for k in tower]
@@ -317,7 +317,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
cb = RunOp(
SyncMultiGPUTrainerReplicated.get_post_init_ops,
run_before=True, run_as_trigger=True, verbose=True)
return train_op, input.get_callbacks() + [cb]
return train_op, callbacks + [cb]
def _setup(self):
self.train_op, cbs = SyncMultiGPUTrainerReplicated.setup_graph(
@@ -379,7 +379,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
[Callback]: the callbacks to be added
"""
input.setup(model.get_inputs_desc())
callbacks = input.setup(model.get_inputs_desc())
raw_devices = ['/gpu:{}'.format(k) for k in tower]
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
@@ -404,7 +404,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
# will call apply_gradients (therefore gradproc) multiple times
train_ops.append(opt.apply_gradients(
grad_and_vars, name='apply_grad_{}'.format(i)))
return tf.group(*train_ops, name='train_op'), input.get_callbacks()
return tf.group(*train_ops, name='train_op'), callbacks
def _setup(self):
self.train_op, cbs = AsyncMultiGPUTrainer.setup_graph(
@@ -53,8 +53,7 @@ class SimpleTrainer(Trainer):
[Callback]: the callbacks to be added
"""
input.setup(model.get_inputs_desc())
cbs = input.get_callbacks()
cbs = input.setup(model.get_inputs_desc())
with TowerContext('', is_training=True):
model.build_graph(input)
_, grads = model.get_cost_and_grad()