Commit 661abb69 authored by Yuxin Wu

GPU prefetch as an option

parent 1cd986d3
@@ -30,11 +30,12 @@ def _check_tf_version():
             "TF version {} is too old to run multi GPU training!".format(tf.VERSION)
 
 
-def apply_prefetch_policy(config, use_stage=True):
+def apply_prefetch_policy(config, gpu_prefetch=True):
     if config.data is None and config.dataflow is not None:
+        # always use Queue prefetch
         config.data = QueueInput(config.dataflow)
         config.dataflow = None
-    if len(config.tower) > 1 and use_stage:
+    if len(config.tower) > 1 and gpu_prefetch:
         assert tf.test.is_gpu_available()
 
         # seem to only improve on >1 GPUs
@@ -146,13 +147,14 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
     GPUs or on CPU.
     """
 
-    def __init__(self, config, ps_device='gpu'):
+    def __init__(self, config, ps_device='gpu', gpu_prefetch=True):
         """
         Args:
             config(TrainConfig):
             ps_device: either 'gpu' or 'cpu', where variables are stored.
+            gpu_prefetch(bool): whether to prefetch the data to each GPU. Usually improves performance.
         """
-        apply_prefetch_policy(config)
+        apply_prefetch_policy(config, gpu_prefetch)
         self._input_source = config.data
         assert ps_device in ['gpu', 'cpu'], ps_device
@@ -219,8 +221,12 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
     Data-parallel Multi-GPU trainer where each GPU contains a replicate of the
     whole model. Each gradient update is broadcast and synced.
     """
-    def __init__(self, config):
-        apply_prefetch_policy(config)
+    def __init__(self, config, gpu_prefetch=True):
+        """
+        Args:
+            config, gpu_prefetch: same as in :class:`SyncMultiGPUTrainerParameterServer`
+        """
+        apply_prefetch_policy(config, gpu_prefetch)
         self._input_source = config.data
         logger.warn("Note that SyncMultiGPUTrainerReplicated doesn't support inference.")
         super(SyncMultiGPUTrainerReplicated, self).__init__(config)
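
For context, a minimal usage sketch of the new option; get_config() below is a hypothetical placeholder for code that builds a TrainConfig, and only the gpu_prefetch keyword comes from this commit:

# Hypothetical usage sketch: get_config() stands in for code that builds a
# TrainConfig with a dataflow and a multi-GPU tower list; only the
# gpu_prefetch keyword is introduced by this commit.
from tensorpack.train import SyncMultiGPUTrainerParameterServer

config = get_config()

# gpu_prefetch=True (the default) keeps the previous behavior: with more
# than one tower, input data is also staged onto each GPU ahead of time.
# gpu_prefetch=False disables the per-GPU staging while the dataflow is
# still fed through a QueueInput on the CPU side.
trainer = SyncMultiGPUTrainerParameterServer(config, ps_device='gpu',
                                             gpu_prefetch=False)
trainer.train()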