Commit 66d5ce80 authored by Yuxin Wu

bugfix in async

parent d869aec8
@@ -27,11 +27,11 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
            'SyncMultiGPUTrainerParameterServer']
 
-def apply_prefetch_policy(config):
+def apply_prefetch_policy(config, use_stage=True):
     if config.data is None and config.dataflow is not None:
         config.data = QueueInput(config.dataflow)
         config.dataflow = None
-    if len(config.tower) > 1:
+    if len(config.tower) > 1 and use_stage:
         assert tf.test.is_gpu_available()
 
         # seem to only improve on >1 GPUs
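For context on what the new use_stage flag toggles: with more than one tower, tensorpack wraps the input in a GPU staging pipeline so the next batch is copied to each GPU while the current one is being consumed. Below is a minimal sketch of that double-buffering pattern written directly against tf.contrib.staging.StagingArea (TF 1.x); the tensors, shapes, and single-GPU setup are illustrative assumptions, not tensorpack's actual wrapper.

import tensorflow as tf  # TF 1.x, where tf.contrib.staging is available

# Illustrative input tensors; in tensorpack these would come from the
# QueueInput created above.
image = tf.placeholder(tf.float32, [None, 224, 224, 3], name='image')
label = tf.placeholder(tf.int32, [None], name='label')

with tf.device('/gpu:0'):
    # Double-buffer one batch on the GPU: put() copies batch k+1 while
    # get() hands batch k to the model, overlapping the host-to-device
    # copy with compute.
    stage = tf.contrib.staging.StagingArea(dtypes=[tf.float32, tf.int32])
    stage_op = stage.put([image, label])
    gpu_image, gpu_label = stage.get()

# Training loop: run stage_op once to warm up the pipeline, then run it
# together with the train op every step so the staging area never drains.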
@@ -204,10 +204,10 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
     Data-parallel Multi-GPU trainer where each GPU contains a replica of the
     whole model. Each gradient update is broadcast and synced.
     """
 
     def __init__(self, config):
         apply_prefetch_policy(config)
         self._input_source = config.data
         logger.warn("Note that SyncMultiGPUTrainerReplicated doesn't support inference.")
         super(SyncMultiGPUTrainerReplicated, self).__init__(config)
 
     @staticmethod
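As a usage sketch for the trainer in this hunk, using the tensorpack API of this era: my_model, my_dataflow, the nr_tower argument, and the import paths are assumptions for illustration; the trainer name and its prefetch behavior come from the diff.

from tensorpack.train.config import TrainConfig
from tensorpack.train.multigpu import SyncMultiGPUTrainerReplicated

# `my_model` and `my_dataflow` are stand-ins: any ModelDesc subclass
# and any DataFlow.
config = TrainConfig(
    model=my_model,
    dataflow=my_dataflow,
    nr_tower=2,   # >1 towers, so staged prefetch stays on (use_stage=True)
)
# __init__ runs apply_prefetch_policy(config): the DataFlow is wrapped in
# a QueueInput, then (with >1 towers) in the staging prefetch shown above.
SyncMultiGPUTrainerReplicated(config).train()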
@@ -288,7 +288,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
             ``1.0/nr_tower``, to make Async and Sync Trainer have the same
             effective learning rate.
         """
-        apply_prefetch_policy(config)
+        apply_prefetch_policy(config, use_stage=False)
+        logger.warn("Async training hasn't been well optimized. Sync training is even faster")
         self._input_source = config.data
         super(AsyncMultiGPUTrainer, self).__init__(config)
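The scaling note in this docstring can be made concrete: with N towers applying updates independently, dividing each tower's gradients by N makes the combined asynchronous updates match, in expectation, a single synchronous update on the averaged gradient. A sketch of that scaling, with illustrative names rather than the trainer's internals:

# grad_list: one [(gradient, variable), ...] list per tower.
nr_tower = len(grad_list)
scaled_grad_list = [
    [(g / float(nr_tower), v) for g, v in tower_grads]
    for tower_grads in grad_list
]
# Sync applies one update per step:    w -= lr * mean(g_1, ..., g_N)
# Async applies N independent updates: w -= lr * (g_i / N), one per tower,
# which sums to the synchronous averaged step in expectation.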