Commit d50341b8 authored by Yuxin Wu's avatar Yuxin Wu

refactor around trainer and add some docs

parent 651a5aea
......@@ -77,6 +77,9 @@ class Trainer(object):
Can be overwritten by subclasses to exploit more
parallelism among predictors.
"""
if len(self.config.predict_tower) > 1:
logger.warn(
"[Speed] Have set multiple predict_tower, but only have naive `get_predict_funcs` implementation")
return [self.get_predict_func(input_names, output_names) for k in range(n)]
def trigger_epoch(self):
......
......@@ -43,7 +43,7 @@ class TrainConfig(object):
max_epoch (int): maximum number of epoch to run training.
nr_tower (int): number of training towers.
tower (list of int): list of training towers in relative id.
predict_tower (list of int): list of prediction towers in their relative gpu id.
predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu.
"""
# TODO type checker decorator
......
......@@ -92,7 +92,7 @@ class SimpleFeedfreeTrainer(
self._input_method = config.data
assert isinstance(self._input_method, FeedfreeInput), self._input_method
super(SimpleFeedfreeTrainer, self).__init__(config)
self._setup_predictor_factory(config.predict_tower)
self._setup_predictor_factory()
assert len(self.config.tower) == 1, \
"SimpleFeedfreeTrainer doesn't support multigpu!"
......@@ -111,22 +111,23 @@ class SimpleFeedfreeTrainer(
class QueueInputTrainer(SimpleFeedfreeTrainer):
"""
A trainer which automatically wraps ``config.dataflow``
A trainer which automatically wraps ``config.dataflow`` by a
:class:`QueueInput`.
"""
def __init__(self, config, input_queue=None, predict_tower=None):
"""
Single tower Trainer, takes input from a queue
:param config: a `TrainConfig` instance. config.dataflow must exist
:param input_queue: a `tf.QueueBase` instance
:param predict_tower: list of gpu relative idx to run prediction. default to be [0].
Use -1 for cpu.
Args:
config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue(tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default.
"""
config.data = QueueInput(config.dataflow, input_queue)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig.predict_tower instead!")
"Use TrainConfig(predict_tower=...) instead!")
config.predict_tower = predict_tower
assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
......
......@@ -19,12 +19,17 @@ __all__ = ['InputData', 'QueueInput', 'FeedfreeInput', 'TensorInput',
@six.add_metaclass(ABCMeta)
class InputData(object):
""" Base class for the abstract InputData. """
pass
class FeedInput(InputData):
""" Input by iterating over a DataFlow and feed datapoints. """
def __init__(self, ds):
"""
Args:
ds (DataFlow): the input DataFlow.
"""
assert isinstance(ds, DataFlow), ds
self.ds = ds
......@@ -44,8 +49,14 @@ class FeedInput(InputData):
class FeedfreeInput(InputData):
""" Abstract base for input without feed,
e.g. by queue or other operations. """
def get_input_tensors(self):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model.
"""
return self._get_input_tensors()
@abstractmethod
......@@ -100,12 +111,14 @@ class EnqueueThread(threading.Thread):
class QueueInput(FeedfreeInput):
""" Input by enqueueing datapoints from a DataFlow to a TF queue, and dequeue
tensors to the graph. """
def __init__(self, ds, queue=None):
"""
:param ds: a `DataFlow` instance
:param queue: a `tf.QueueBase` instance to be used to buffer datapoints.
Defaults to a FIFO queue of size 50.
Args:
ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): Defaults to a FIFO queue of size 50.
"""
assert isinstance(ds, DataFlow), ds
self.queue = queue
......@@ -142,11 +155,10 @@ class QueueInput(FeedfreeInput):
return ret
class DummyConstantInput(QueueInput):
""" only for debugging performance issues """
class DummyConstantInput(FeedfreeInput):
""" Input some constant variables. Only for debugging performance issues """
def __init__(self, ds, shapes):
super(DummyConstantInput, self).__init__(ds)
def __init__(self, shapes):
self.shapes = shapes
logger.warn("Using dummy input for debug!")
......@@ -163,8 +175,15 @@ class DummyConstantInput(QueueInput):
class TensorInput(FeedfreeInput):
""" Input from a list of tensors, e.g. a TF data reading pipeline. """
def __init__(self, get_tensor_fn, size=None):
"""
Args:
get_tensor_fn: a function which returns a list of input tensors
when called.
size(int): size of this input. Use None to leave it undefined.
"""
self.get_tensor_fn = get_tensor_fn
self._size = size
......
......@@ -21,7 +21,7 @@ from .trainer import MultiPredictorTowerTrainer
from .feedfree import SingleCostFeedfreeTrainer
from .input_data import QueueInput
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
__all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']
class MultiGPUTrainer(Trainer):
......@@ -51,8 +51,16 @@ class MultiGPUTrainer(Trainer):
class SyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer,
MultiPredictorTowerTrainer):
"""
A multi-tower multi-GPU trainer which synchronoizes the gradients computed
from each tower and averages them.
"""
def __init__(self, config, input_queue=None, predict_tower=None):
"""
Args:
config, input_queue: same as in :class:`QueueInputTrainer`.
"""
if config.dataflow is not None:
self._input_method = QueueInput(config.dataflow, input_queue)
else:
......@@ -65,7 +73,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
config.predict_tower = predict_tower
super(SyncMultiGPUTrainer, self).__init__(config)
self._setup_predictor_factory(config.predict_tower)
self._setup_predictor_factory()
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
assert tf.test.is_gpu_available()
......@@ -117,11 +125,22 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
class AsyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer,
MultiPredictorTowerTrainer):
"""
A multi-tower multi-GPU trainer where each tower independently
asynchronously updates the model without locking.
"""
def __init__(self, config,
input_queue=None,
average_gradient=True,
scale_gradient=True,
predict_tower=None):
"""
Args:
config, input_queue: same as in :class:`QueueInputTrainer`.
scale_gradient (bool): if True, will scale each gradient by
``1.0/nr_tower``, to make Async and Sync Trainer have the same
effective learning rate.
"""
if config.dataflow is not None:
self._input_method = QueueInput(config.dataflow, input_queue)
else:
......@@ -134,8 +153,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
"Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
self._setup_predictor_factory(config.predict_tower)
self._average_gradient = average_gradient
self._setup_predictor_factory()
self._scale_gradient = scale_gradient
assert tf.test.is_gpu_available()
def _setup(self):
......@@ -143,7 +162,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1])
gradprocs = self.model.get_gradient_processor()
if self._average_gradient and self.config.nr_tower > 1:
if self._scale_gradient and self.config.nr_tower > 1:
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False))
......
......@@ -54,9 +54,14 @@ class PredictorFactory(object):
class SimpleTrainer(Trainer):
""" A naive demo trainer """
""" A naive demo trainer which iterates over a DataFlow and feed into the
graph. It's not efficient compared to QueueInputTrainer or others."""
def __init__(self, config):
"""
Args:
config (TrainConfig): the training config.
"""
super(SimpleTrainer, self).__init__(config)
self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
if config.dataflow is None:
......@@ -66,6 +71,7 @@ class SimpleTrainer(Trainer):
self._input_method = FeedInput(config.dataflow)
def run_step(self):
""" Feed data into the graph and run the updates. """
feed = self._input_method.next_feed()
self.sess.run([self.train_op], feed_dict=feed) # faster since train_op return None
......@@ -99,11 +105,10 @@ class SimpleTrainer(Trainer):
class MultiPredictorTowerTrainer(Trainer):
""" A trainer with possibly multiple prediction tower """
def _setup_predictor_factory(self, predict_tower):
def _setup_predictor_factory(self):
# by default, use the first training gpu for prediction
predict_tower = predict_tower or [0]
self._predictor_factory = PredictorFactory(
self.sess, self.model, predict_tower)
self.sess, self.model, self.config.predict_tower)
def get_predict_func(self, input_names, output_names, tower=0):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment