Commit d50341b8 authored by Yuxin Wu

refactor around trainer and add some docs

parent 651a5aea
@@ -77,6 +77,9 @@ class Trainer(object):
         Can be overwritten by subclasses to exploit more
         parallelism among predictors.
         """
+        if len(self.config.predict_tower) > 1:
+            logger.warn(
+                "[Speed] Have set multiple predict_tower, but only have naive `get_predict_funcs` implementation")
         return [self.get_predict_func(input_names, output_names) for k in range(n)]

     def trigger_epoch(self):
...
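A note on the new `[Speed]` warning: the base implementation still builds all `n` predictors the same way, so extra `predict_tower` entries buy nothing here. A hypothetical sketch of the kind of subclass that could exploit them (assuming the `get_predict_func(..., tower=...)` signature that `MultiPredictorTowerTrainer` exposes later in this diff):

```python
# Hypothetical subclass, not part of this commit: spread the n requested
# predictors across the configured predict towers round-robin.
class RoundRobinPredictorTrainer(MultiPredictorTowerTrainer):
    def get_predict_funcs(self, input_names, output_names, n):
        towers = self.config.predict_tower
        # predictor k runs against predict tower index k % len(towers)
        return [self.get_predict_func(input_names, output_names,
                                      tower=k % len(towers))
                for k in range(n)]
```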
@@ -43,7 +43,7 @@ class TrainConfig(object):
         max_epoch (int): maximum number of epoch to run training.
         nr_tower (int): number of training towers.
         tower (list of int): list of training towers in relative id.
-        predict_tower (list of int): list of prediction towers in their relative gpu id.
+        predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu.
         """
         # TODO type checker decorator
...
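For context, the new "Use -1 for cpu" doc implies a config like the following minimal sketch. Field names follow the docstring above; `my_dataflow` and `MyModel` are placeholders, and the other `TrainConfig` arguments are assumed from the rest of the repo and may differ at your revision:

```python
import tensorflow as tf

config = TrainConfig(
    dataflow=my_dataflow,                    # placeholder DataFlow
    optimizer=tf.train.AdamOptimizer(1e-3),
    model=MyModel(),                         # placeholder ModelDesc
    tower=[0, 1],                            # train on GPUs 0 and 1 (relative ids)
    predict_tower=[-1],                      # run prediction on CPU, per the new doc
)
```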
@@ -92,7 +92,7 @@ class SimpleFeedfreeTrainer(
         self._input_method = config.data
         assert isinstance(self._input_method, FeedfreeInput), self._input_method
         super(SimpleFeedfreeTrainer, self).__init__(config)
-        self._setup_predictor_factory(config.predict_tower)
+        self._setup_predictor_factory()
         assert len(self.config.tower) == 1, \
             "SimpleFeedfreeTrainer doesn't support multigpu!"
@@ -111,22 +111,23 @@ class SimpleFeedfreeTrainer(
 class QueueInputTrainer(SimpleFeedfreeTrainer):
     """
-    A trainer which automatically wraps ``config.dataflow``
+    A trainer which automatically wraps ``config.dataflow`` by a
+    :class:`QueueInput`.
     """
     def __init__(self, config, input_queue=None, predict_tower=None):
         """
         Single tower Trainer, takes input from a queue
-        :param config: a `TrainConfig` instance. config.dataflow must exist
-        :param input_queue: a `tf.QueueBase` instance
-        :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
-            Use -1 for cpu.
+        Args:
+            config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
+            input_queue(tf.QueueBase): an input queue. Defaults to the
+                :class:`QueueInput` default.
         """
         config.data = QueueInput(config.dataflow, input_queue)
         if predict_tower is not None:
             logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
-                        "Use TrainConfig.predict_tower instead!")
+                        "Use TrainConfig(predict_tower=...) instead!")
             config.predict_tower = predict_tower
         assert len(config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
...
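The deprecation path in code, as a sketch (`my_dataflow`, `my_optimizer`, and `my_model` are placeholders; `train()` is assumed to be the usual trainer entry point):

```python
# Old style -- still works, but now triggers the logger.warn above:
trainer = QueueInputTrainer(config, predict_tower=[0])

# New style -- carry the setting in TrainConfig instead:
config = TrainConfig(
    dataflow=my_dataflow, optimizer=my_optimizer, model=my_model,  # placeholders
    predict_tower=[0],
)
QueueInputTrainer(config).train()
```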
@@ -19,12 +19,17 @@ __all__ = ['InputData', 'QueueInput', 'FeedfreeInput', 'TensorInput',
 @six.add_metaclass(ABCMeta)
 class InputData(object):
+    """ Base class for the abstract InputData. """
     pass


 class FeedInput(InputData):
+    """ Input by iterating over a DataFlow and feeding datapoints. """
     def __init__(self, ds):
+        """
+        Args:
+            ds (DataFlow): the input DataFlow.
+        """
         assert isinstance(ds, DataFlow), ds
         self.ds = ds
@@ -44,8 +49,14 @@ class FeedInput(InputData):
 class FeedfreeInput(InputData):
+    """ Abstract base for input without feed,
+    e.g. by queue or other operations. """
     def get_input_tensors(self):
+        """
+        Returns:
+            list: A list of tensors corresponding to the inputs of the model.
+        """
         return self._get_input_tensors()

     @abstractmethod
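To illustrate the contract documented here, a minimal sketch of a custom feedfree input (hypothetical class; `tf.random_uniform` stands in for a real reading pipeline):

```python
import tensorflow as tf

class RandomNoiseInput(FeedfreeInput):
    """ Hypothetical example: inputs produced by TF ops, no feeds. """
    def __init__(self, shape):
        self._shape = shape

    def _get_input_tensors(self):
        # one tensor per model input; here a single random batch
        return [tf.random_uniform(self._shape)]
```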
@@ -100,12 +111,14 @@ class EnqueueThread(threading.Thread):
 class QueueInput(FeedfreeInput):
+    """ Input by enqueueing datapoints from a DataFlow to a TF queue, and
+    dequeueing tensors into the graph. """
     def __init__(self, ds, queue=None):
         """
-        :param ds: a `DataFlow` instance
-        :param queue: a `tf.QueueBase` instance to be used to buffer datapoints.
-            Defaults to a FIFO queue of size 50.
+        Args:
+            ds(DataFlow): the input DataFlow.
+            queue (tf.QueueBase): Defaults to a FIFO queue of size 50.
         """
         assert isinstance(ds, DataFlow), ds
         self.queue = queue
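Passing a custom queue, as a sketch: the dtypes must match the datapoint components of `ds`; the `(image, label)` layout below is an assumption.

```python
import tensorflow as tf

# replace the default 50-element FIFO with a larger one
queue = tf.FIFOQueue(200, [tf.float32, tf.int32], name='input_queue')
qinput = QueueInput(my_dataflow, queue=queue)   # my_dataflow: placeholder DataFlow
```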
@@ -142,11 +155,10 @@ class QueueInput(FeedfreeInput):
         return ret


-class DummyConstantInput(QueueInput):
-    """ only for debugging performance issues """
-    def __init__(self, ds, shapes):
-        super(DummyConstantInput, self).__init__(ds)
+class DummyConstantInput(FeedfreeInput):
+    """ Input some constant variables. Only for debugging performance issues """
+    def __init__(self, shapes):
         self.shapes = shapes
         logger.warn("Using dummy input for debug!")
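Call-site migration for the signature change, as a one-line sketch (the shapes are made up):

```python
# old: DummyConstantInput(ds, shapes) -- the DataFlow argument is gone
dummy = DummyConstantInput([[64, 224, 224, 3], [64]])
```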
@@ -163,8 +175,15 @@ class DummyConstantInput(QueueInput):
 class TensorInput(FeedfreeInput):
+    """ Input from a list of tensors, e.g. a TF data reading pipeline. """
     def __init__(self, get_tensor_fn, size=None):
+        """
+        Args:
+            get_tensor_fn: a function which returns a list of input tensors
+                when called.
+            size(int): size of this input. Use None to leave it undefined.
+        """
         self.get_tensor_fn = get_tensor_fn
         self._size = size
...
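A sketch of how `TensorInput` is meant to be used. The "pipeline" here is a stand-in; any callable returning the model's input tensors works, and `size=1000` assumes `size` counts batches per epoch:

```python
import tensorflow as tf

def get_tensors():
    # stand-in for a real TF reading pipeline (readers, decoders, batching)
    image = tf.random_uniform([128, 28, 28, 1])
    label = tf.zeros([128], dtype=tf.int32)
    return [image, label]

# size=None would leave the epoch size undefined, per the docstring above
tensor_input = TensorInput(get_tensors, size=1000)
```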
@@ -21,7 +21,7 @@ from .trainer import MultiPredictorTowerTrainer
 from .feedfree import SingleCostFeedfreeTrainer
 from .input_data import QueueInput

-__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
+__all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']


 class MultiGPUTrainer(Trainer):
@@ -51,8 +51,16 @@ class MultiGPUTrainer(Trainer):
 class SyncMultiGPUTrainer(MultiGPUTrainer,
                           SingleCostFeedfreeTrainer,
                           MultiPredictorTowerTrainer):
+    """
+    A multi-tower multi-GPU trainer which synchronizes the gradients computed
+    from each tower and averages them.
+    """
     def __init__(self, config, input_queue=None, predict_tower=None):
+        """
+        Args:
+            config, input_queue: same as in :class:`QueueInputTrainer`.
+        """
         if config.dataflow is not None:
             self._input_method = QueueInput(config.dataflow, input_queue)
         else:
@@ -65,7 +73,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
             config.predict_tower = predict_tower
         super(SyncMultiGPUTrainer, self).__init__(config)
-        self._setup_predictor_factory(config.predict_tower)
+        self._setup_predictor_factory()
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
         assert tf.test.is_gpu_available()
@@ -117,11 +125,22 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
 class AsyncMultiGPUTrainer(MultiGPUTrainer,
                            SingleCostFeedfreeTrainer,
                            MultiPredictorTowerTrainer):
+    """
+    A multi-tower multi-GPU trainer where each tower independently and
+    asynchronously updates the model without locking.
+    """
     def __init__(self, config,
                  input_queue=None,
-                 average_gradient=True,
+                 scale_gradient=True,
                  predict_tower=None):
+        """
+        Args:
+            config, input_queue: same as in :class:`QueueInputTrainer`.
+            scale_gradient (bool): if True, will scale each gradient by
+                ``1.0/nr_tower``, to make Async and Sync Trainer have the same
+                effective learning rate.
+        """
         if config.dataflow is not None:
             self._input_method = QueueInput(config.dataflow, input_queue)
         else:
@@ -134,8 +153,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
                         "Use TrainConfig.predict_tower instead!")
             config.predict_tower = predict_tower
-        self._setup_predictor_factory(config.predict_tower)
-        self._average_gradient = average_gradient
+        self._setup_predictor_factory()
+        self._scale_gradient = scale_gradient
         assert tf.test.is_gpu_available()

     def _setup(self):
@@ -143,7 +162,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         grad_list = MultiGPUTrainer._multi_tower_grads(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
         gradprocs = self.model.get_gradient_processor()
-        if self._average_gradient and self.config.nr_tower > 1:
+        if self._scale_gradient and self.config.nr_tower > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
             gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False))
...
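Why the `1.0/nr_tower` scaling keeps Async and Sync at the same effective learning rate, as a toy calculation:

```python
# With 4 towers and per-tower gradients g1..g4, one sync step applies
#   lr * (g1+g2+g3+g4)/4
# while async applies lr * g_i once per tower, i.e. lr * (g1+g2+g3+g4)
# in total -- 4x larger. Scaling each async gradient by 1/4 matches them.
nr_tower = 4
grads = [0.8, 1.2, 1.0, 1.0]                              # toy gradients
sync_step = sum(grads) / nr_tower                          # averaged once
async_total = sum(g * (1.0 / nr_tower) for g in grads)     # scaled, summed
assert abs(sync_step - async_total) < 1e-9
```

Passing `scale_gradient=False` skips the `ScaleGradient` processor and lets each tower apply its raw gradient.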
@@ -54,9 +54,14 @@ class PredictorFactory(object):
 class SimpleTrainer(Trainer):
-    """ A naive demo trainer """
+    """ A naive demo trainer which iterates over a DataFlow and feeds into the
+    graph. It's not efficient compared to QueueInputTrainer or others."""
     def __init__(self, config):
+        """
+        Args:
+            config (TrainConfig): the training config.
+        """
         super(SimpleTrainer, self).__init__(config)
         self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
         if config.dataflow is None:
@@ -66,6 +71,7 @@ class SimpleTrainer(Trainer):
         self._input_method = FeedInput(config.dataflow)

     def run_step(self):
+        """ Feed data into the graph and run the updates. """
         feed = self._input_method.next_feed()
         self.sess.run([self.train_op], feed_dict=feed)  # faster since train_op returns None
@@ -99,11 +105,10 @@ class SimpleTrainer(Trainer):
 class MultiPredictorTowerTrainer(Trainer):
     """ A trainer with possibly multiple prediction towers """
-    def _setup_predictor_factory(self, predict_tower):
+    def _setup_predictor_factory(self):
         # by default, use the first training gpu for prediction
-        predict_tower = predict_tower or [0]
         self._predictor_factory = PredictorFactory(
-            self.sess, self.model, predict_tower)
+            self.sess, self.model, self.config.predict_tower)

     def get_predict_func(self, input_names, output_names, tower=0):
         """
...
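End to end, the refactored predictor setup reads like this sketch. Placeholders are assumed throughout, `tower` is assumed to index into `predict_tower`, and the tensor names `'input'`/`'output'` are hypothetical:

```python
config = TrainConfig(
    dataflow=my_dataflow, optimizer=my_optimizer, model=my_model,  # placeholders
    tower=[0, 1],
    predict_tower=[0, 1],       # was a trainer argument before this commit
)
trainer = SyncMultiGPUTrainer(config)
# one predictor bound to the second configured predict tower
pred = trainer.get_predict_func(['input'], ['output'], tower=1)
```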