Commit 1dbf6154 authored by Yuxin Wu's avatar Yuxin Wu

Still use gpu as default ps device

parent 710bf4eb
......@@ -151,19 +151,16 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
See https://www.tensorflow.org/performance/benchmarks for details.
"""
def __init__(self, config, ps_device=None, gpu_prefetch=True):
def __init__(self, config, ps_device='gpu', gpu_prefetch=True):
"""
Args:
config(TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help if #gpu>=4
Defaults to 'cpu' when #gpu >= 4.
ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
gpu_prefetch(bool): whether to prefetch the data to each GPU. Usually improve performance.
"""
apply_prefetch_policy(config, gpu_prefetch)
self._input_source = config.data
if ps_device is None:
ps_device = 'cpu' if config.nr_tower >= 4 else 'gpu'
assert ps_device in ['gpu', 'cpu'], ps_device
self._ps_device = ps_device
super(SyncMultiGPUTrainerParameterServer, self).__init__(config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment