Commit 22b91be9 authored by Yuxin Wu

map_arg for gpus

parent dc709e94
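
In short: after this commit, every multi-GPU trainer's `__init__` accepts either a list of GPU ids or a plain integer GPU count, normalized through `map_arg(gpus=_int_to_range)`. Below is a self-contained sketch of that mechanism (an approximation for illustration, not the exact tensorpack source; `FakeTrainer` is a hypothetical stand-in for the trainer classes):

# Sketch of the mechanism this commit introduces (approximation, not the
# exact tensorpack source). map_arg() wraps a function so that selected
# arguments are passed through a mapping function before the call; with
# gpus=_int_to_range, an integer becomes list(range(n)) while an explicit
# list of GPU ids passes through unchanged.
import functools
import inspect


def map_arg(**maps):
    """Apply a mapping function to selected arguments before calling func."""
    def deco(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = inspect.signature(func).bind(*args, **kwargs)
            for name, fn in maps.items():
                if name in bound.arguments:
                    bound.arguments[name] = fn(bound.arguments[name])
            return func(*bound.args, **bound.kwargs)
        return wrapper
    return deco


def _int_to_range(x):
    # An int means "use this many GPUs"; anything else is returned as-is.
    if isinstance(x, int):
        assert x > 0, x
        return list(range(x))
    return x


class FakeTrainer:  # hypothetical, for illustration only
    @map_arg(gpus=_int_to_range)
    def __init__(self, gpus):
        self.gpus = gpus  # guaranteed to be a list of GPU ids here


assert FakeTrainer(4).gpus == [0, 1, 2, 3]   # int: number of GPUs
assert FakeTrainer([0, 2]).gpus == [0, 2]    # list: explicit GPU ids
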
@@ -322,5 +322,4 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    config.nr_tower = nr_tower
-    launch_train_with_config(config, SyncMultiGPUTrainer(list(range(nr_tower))))
+    launch_train_with_config(config, SyncMultiGPUTrainer(nr_tower))
@@ -264,4 +264,4 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    launch_train_with_config(config, SyncMultiGPUTrainer(list(range(NR_GPU))))
+    launch_train_with_config(config, SyncMultiGPUTrainer(NR_GPU))
@@ -302,5 +302,5 @@ if __name__ == '__main__':
        max_epoch=205000 * factor // stepnum,
        session_init=get_model_loader(args.load) if args.load else None,
     )
-    trainer = SyncMultiGPUTrainerReplicated(list(range(get_nr_gpu())))
+    trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu())
     launch_train_with_config(cfg, trainer)
@@ -234,4 +234,4 @@ if __name__ == '__main__':
         config.session_init = get_model_loader(args.load)
     launch_train_with_config(
         config,
-        SyncMultiGPUTrainer(range(max(get_nr_gpu(), 1))))
+        SyncMultiGPUTrainer(max(get_nr_gpu(), 1)))
@@ -8,6 +8,7 @@ from ..callbacks.graph import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
 from ..utils import logger
+from ..utils.argtools import map_arg
 from ..tfutils import get_global_step_var
 from ..tfutils.distributed import get_distributed_session_creator
 from ..tfutils.tower import TowerContext
@@ -31,6 +32,12 @@ __all__ = ['SimpleTrainer',
            'DistributedTrainerReplicated']


+def _int_to_range(x):
+    if isinstance(x, int):
+        assert x > 0, x
+        return list(range(x))
+    return x
+

 class SimpleTrainer(SingleCostTrainer):
     """
     Single-GPU single-cost single-tower trainer.
@@ -54,13 +61,14 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):

     __doc__ = SyncMultiGPUParameterServerBuilder.__doc__

-    def __init__(self, towers, ps_device='gpu'):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, ps_device='gpu'):
         """
         Args:
-            towers ([int]): list of GPU ids.
+            gpus ([int]): list of GPU ids.
             ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
         """
-        self._builder = SyncMultiGPUParameterServerBuilder(towers, ps_device)
+        self._builder = SyncMultiGPUParameterServerBuilder(gpus, ps_device)
         super(SyncMultiGPUTrainerParameterServer, self).__init__()

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
@@ -69,28 +77,29 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
         return []


-def SyncMultiGPUTrainer(towers):
+def SyncMultiGPUTrainer(gpus):
     """
     Return a default multi-GPU trainer, if you don't care about the details.
     It may not be the most efficient one for your task.

     Args:
-        towers (list[int]): list of GPU ids.
+        gpus (list[int]): list of GPU ids.
     """
-    return SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu')
+    return SyncMultiGPUTrainerParameterServer(gpus, ps_device='gpu')


 class AsyncMultiGPUTrainer(SingleCostTrainer):

     __doc__ = AsyncMultiGPUBuilder.__doc__

-    def __init__(self, towers, scale_gradient=True):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, scale_gradient=True):
         """
         Args:
-            towers ([int]): list of GPU ids.
+            gpus ([int]): list of GPU ids.
             scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
         """
-        self._builder = AsyncMultiGPUBuilder(towers, scale_gradient)
+        self._builder = AsyncMultiGPUBuilder(gpus, scale_gradient)
         super(AsyncMultiGPUTrainer, self).__init__()

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
@@ -103,12 +112,13 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):

     __doc__ = SyncMultiGPUReplicatedBuilder.__doc__

-    def __init__(self, towers):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus):
         """
         Args:
-            towers ([int]): list of GPU ids.
+            gpus ([int]): list of GPU ids.
         """
-        self._builder = SyncMultiGPUReplicatedBuilder(towers)
+        self._builder = SyncMultiGPUReplicatedBuilder(gpus)
         super(SyncMultiGPUTrainerReplicated, self).__init__()

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
@@ -125,10 +135,11 @@ class DistributedTrainerReplicated(SingleCostTrainer):

     __doc__ = DistributedReplicatedBuilder.__doc__

-    def __init__(self, towers, server):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, server):
         """
         Args:
-            towers (list[int]): list of GPU ids.
+            gpus (list[int]): list of GPU ids.
             server (tf.train.Server): the server with ps and workers.
                 The job_name must be 'worker' because 'ps' job doesn't need to
                 build any graph.
@@ -139,7 +150,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
         if self.job_name == 'worker':
             # ps doesn't build any graph
-            self._builder = DistributedReplicatedBuilder(towers, server)
+            self._builder = DistributedReplicatedBuilder(gpus, server)
             self.is_chief = self._builder.is_chief
         else:
             self.is_chief = False
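
For reference, a usage sketch matching the call sites updated above (module paths are assumed from the tensorpack layout of this period, not verified against this exact revision):

# Both forms construct the same trainer after this commit.
from tensorpack.train import SyncMultiGPUTrainerReplicated
from tensorpack.utils.gpu import get_nr_gpu

nr_gpu = max(get_nr_gpu(), 1)
trainer = SyncMultiGPUTrainerReplicated(nr_gpu)  # int: GPU count
# equivalent list form: SyncMultiGPUTrainerReplicated(list(range(nr_gpu)))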