Commit fb43cf03 authored by Yuxin Wu

Use StageArea by default in SyncMultiGPUTrainer. fix #140

parent ba2758b3
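This commit makes SyncMultiGPUTrainer wrap its input pipeline in a StagingInputWrapper by default, so each training tower reads its batch from a TensorFlow StagingArea that already lives on the GPU: the put of the next batch runs concurrently with the compute of the current step, hiding the host-to-device copy. Below is a minimal sketch of the underlying StagingArea pattern in TF 1.x, not tensorpack's actual StagingInputWrapper; the placeholder shapes and the session loop are illustrative assumptions.

import tensorflow as tf

# Illustrative input tensors; a real pipeline would dequeue them from a queue.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
labels = tf.placeholder(tf.int32, [None])

with tf.device('/gpu:0'):
    # One StagingArea per GPU buffers the next batch on-device.
    stage = tf.contrib.staging.StagingArea(dtypes=[tf.float32, tf.int32])
    stage_op = stage.put([images, labels])       # copies the NEXT batch to the GPU
    staged_images, staged_labels = stage.get()   # returns the batch staged last step

# Build the model on staged_images/staged_labels, then run each training step as
#   sess.run([train_op, stage_op], feed_dict=...)
# so the copy for step k+1 overlaps the compute of step k. One warm-up
# sess.run(stage_op, feed_dict=...) is needed to fill the pipeline.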
@@ -41,10 +41,9 @@ It's Yet Another TF wrapper, but different in:
 3. Focus on training speed.
   + Tensorpack trainer is almost always faster than `feed_dict` based wrappers.
     Even on a small CNN example, the training runs [2x faster](https://gist.github.com/ppwwyyxx/8d95da79f8d97036a7d67c2416c851b6) than the equivalent Keras code.
-    More improvements to come later.
-  + Data-Parallel Multi-GPU training is off-the-shelf to use.
-    You can also define your own trainer for different style of training (e.g. GAN) without losing the efficiency.
+  + Data-Parallel Multi-GPU training is off-the-shelf to use. For <=4 GPUs it is as fast as [tensorflow/benchmarks](https://github.com/tensorflow/benchmarks).
+    More improvements to come later.

 4. Interface of extensible __Callbacks__.
   Write a callback to implement everything you want to do apart from the training iterations, and
@@ -17,7 +17,7 @@ from ..tfutils.gradproc import FilterNoneGrad, ScaleGradient
 from .base import Trainer
 from .feedfree import SingleCostFeedfreeTrainer
-from .input_data import QueueInput
+from .input_data import QueueInput, StagingInputWrapper

 __all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']
@@ -76,12 +76,16 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
         else:
             assert input_queue is None, input_queue
             self._input_method = config.data
-            # assert isinstance(self._input_method, QueueInput)
-        super(SyncMultiGPUTrainer, self).__init__(config)
+
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one tower."
         if len(config.tower) > 1:
             assert tf.test.is_gpu_available()

+        if not isinstance(self._input_method, StagingInputWrapper):
+            devices = ['/gpu:{}'.format(k) for k in config.tower]
+            self._input_method = StagingInputWrapper(self._input_method, devices)
+
+        super(SyncMultiGPUTrainer, self).__init__(config)
         self.average_cost = average_cost

     @staticmethod
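With the hunk above, any input method that is not already a StagingInputWrapper is wrapped automatically, with one staging device per tower in config.tower, and the parent constructor now runs only afterwards so it sees the wrapped input. A hedged sketch of the resulting call site; MyModel and my_dataflow are placeholder names, and exact import paths vary across tensorpack versions:

# Hypothetical usage; MyModel / my_dataflow are placeholders.
from tensorpack.train.config import TrainConfig
from tensorpack.train.input_data import QueueInput
from tensorpack.train.multigpu import SyncMultiGPUTrainer

config = TrainConfig(
    data=QueueInput(my_dataflow),   # a plain, non-staged input method
    model=MyModel(),
    tower=[0, 1],                   # two training towers
)
# The constructor now rewrites config.data to
#   StagingInputWrapper(QueueInput(...), ['/gpu:0', '/gpu:1'])
# so staged input is on by default, with no change to user code.
SyncMultiGPUTrainer(config).train()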
@@ -161,7 +165,6 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         else:
             assert input_queue is None, input_queue
             self._input_method = config.data
-            assert isinstance(self._input_method, QueueInput)
         super(AsyncMultiGPUTrainer, self).__init__(config)

         self._scale_gradient = scale_gradient
...@@ -194,7 +197,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -194,7 +197,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
train_op = self.config.optimizer.apply_gradients(grad_list[k]) train_op = self.config.optimizer.apply_gradients(grad_list[k])
def f(op=train_op): # avoid late-binding def f(op=train_op): # avoid late-binding
self.sess.run([op]) self.sess.run([op]) # TODO this won't work with StageInput
next(self.async_step_counter) # atomic due to GIL next(self.async_step_counter) # atomic due to GIL
th = LoopThread(f) th = LoopThread(f)
th.name = "AsyncLoopThread-{}".format(k) th.name = "AsyncLoopThread-{}".format(k)
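The def f(op=train_op) idiom in the hunk above (flagged by the "# avoid late-binding" comment) is the standard Python fix for late binding: a closure captures the loop variable, not its value, so without the default argument every per-tower LoopThread would end up running the last tower's train_op. A standalone illustration:

# Closures see the loop variable's final value: every function returns 2.
fns = [lambda: k for k in range(3)]
print([f() for f in fns])    # -> [2, 2, 2]

# A default argument is evaluated at definition time, freezing each value.
fns = [lambda k=k: k for k in range(3)]
print([f() for f in fns])    # -> [0, 1, 2]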