Commit 66d5ce80 authored by Yuxin Wu

bugfix in async

parent d869aec8
@@ -27,11 +27,11 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
            'SyncMultiGPUTrainerParameterServer']


-def apply_prefetch_policy(config):
+def apply_prefetch_policy(config, use_stage=True):
     if config.data is None and config.dataflow is not None:
         config.data = QueueInput(config.dataflow)
         config.dataflow = None
-    if len(config.tower) > 1:
+    if len(config.tower) > 1 and use_stage:
         assert tf.test.is_gpu_available()

         # seem to only improve on >1 GPUs
@@ -204,10 +204,11 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
     Data-parallel Multi-GPU trainer where each GPU contains a replicate of the
     whole model. Each gradient update is broadcast and synced.
     """

     def __init__(self, config):
         apply_prefetch_policy(config)
         self._input_source = config.data
+        logger.warn("Note that SyncMultiGPUTrainerReplicated doesn't support inference.")
         super(SyncMultiGPUTrainerReplicated, self).__init__(config)

     @staticmethod
@@ -288,7 +288,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
         ``1.0/nr_tower``, to make Async and Sync Trainer have the same
         effective learning rate.
         """
-        apply_prefetch_policy(config)
+        apply_prefetch_policy(config, use_stage=False)
+        logger.warn("Async training hasn't been well optimized. Sync training is even faster")
         self._input_source = config.data
         super(AsyncMultiGPUTrainer, self).__init__(config)
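For context on what `use_stage` toggles: tensorpack's multi-GPU trainers can wrap the input source with a TensorFlow StagingArea, so the next batch is copied to the GPU while the current step computes. Async towers don't step in lockstep, so this commit disables staging for async training. Below is a minimal sketch of the underlying TF 1.x staging pattern; the placeholder shape/dtype and single-GPU setup are illustrative assumptions, not code from this repo:

```python
import tensorflow as tf

# Sketch of StagingArea prefetching: put() copies the next batch onto
# the GPU while get() hands back the batch staged on the previous step.
x = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name='input')
with tf.device('/gpu:0'):
    stage = tf.contrib.staging.StagingArea(dtypes=[tf.float32])
    stage_op = stage.put([x])   # enqueue next batch (host-to-device copy)
    gpu_x, = stage.get()        # dequeue the batch staged one step earlier

# A training step runs the copy and the compute together, overlapping them:
# sess.run([train_op, stage_op], feed_dict={x: next_batch})
```

This overlap is only a win when multiple GPU towers keep the pipeline busy in lockstep, which is why `apply_prefetch_policy` checks `len(config.tower) > 1` and why the async trainer now opts out.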
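For reference, a hypothetical usage sketch of the fixed trainer; `MyModel`, `my_dataflow`, and the exact `TrainConfig` arguments are assumptions for illustration, not taken from this diff:

```python
from tensorpack import TrainConfig, AsyncMultiGPUTrainer

# Hypothetical model and dataflow; MyModel (a ModelDesc subclass) and
# my_dataflow (a DataFlow) are stand-ins.
config = TrainConfig(
    model=MyModel(),
    dataflow=my_dataflow,
    tower=[0, 1],   # two GPU towers, so len(config.tower) > 1
)
# After this commit the trainer calls
# apply_prefetch_policy(config, use_stage=False), so the
# `len(config.tower) > 1 and use_stage` branch is skipped and no
# StagingArea prefetch is set up for async training.
AsyncMultiGPUTrainer(config).train()
```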