Commit 22b91be9 authored by Yuxin Wu

map_arg for gpus

parent dc709e94
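
In short: this commit lets the multi-GPU trainers accept either a plain GPU count or an explicit list of GPU ids. An integer argument is expanded to list(range(n)) by the new _int_to_range helper, applied through the map_arg decorator, and the example scripts are updated to pass counts directly. A minimal usage sketch of the new call sites (assuming the usual top-level tensorpack imports; the GPU counts are only illustrative):

    from tensorpack import SyncMultiGPUTrainer, SyncMultiGPUTrainerReplicated

    # After this change, both spellings construct the same trainer:
    # an int is expanded to a list of GPU ids internally.
    trainer_a = SyncMultiGPUTrainer(2)        # equivalent to SyncMultiGPUTrainer([0, 1])
    trainer_b = SyncMultiGPUTrainer([0, 1])   # explicit GPU ids still work as before

    # The replicated and async trainers get the same treatment.
    trainer_c = SyncMultiGPUTrainerReplicated(4)   # equivalent to list(range(4))
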
@@ -322,5 +322,4 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    config.nr_tower = nr_tower
-    launch_train_with_config(config, SyncMultiGPUTrainer(list(range(nr_tower))))
+    launch_train_with_config(config, SyncMultiGPUTrainer(nr_tower))
@@ -264,4 +264,4 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    launch_train_with_config(config, SyncMultiGPUTrainer(list(range(NR_GPU))))
+    launch_train_with_config(config, SyncMultiGPUTrainer(NR_GPU))
@@ -302,5 +302,5 @@ if __name__ == '__main__':
         max_epoch=205000 * factor // stepnum,
         session_init=get_model_loader(args.load) if args.load else None,
     )
-    trainer = SyncMultiGPUTrainerReplicated(range(len(get_nr_gpu())))
+    trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu())
     launch_train_with_config(cfg, trainer)
@@ -234,4 +234,4 @@ if __name__ == '__main__':
     config.session_init = get_model_loader(args.load)
     launch_train_with_config(
         config,
-        SyncMultiGPUTrainer(range(max(get_nr_gpu(), 1))))
+        SyncMultiGPUTrainer(max(get_nr_gpu(), 1)))
@@ -8,6 +8,7 @@ from ..callbacks.graph import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
 from ..utils import logger
+from ..utils.argtools import map_arg
 from ..tfutils import get_global_step_var
 from ..tfutils.distributed import get_distributed_session_creator
 from ..tfutils.tower import TowerContext
@@ -31,6 +32,12 @@ __all__ = ['SimpleTrainer',
            'DistributedTrainerReplicated']
+def _int_to_range(x):
+    if isinstance(x, int):
+        assert x > 0, x
+        return list(range(x))
+    return x
 class SimpleTrainer(SingleCostTrainer):
     """
     Single-GPU single-cost single-tower trainer.
@@ -54,13 +61,14 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
     __doc__ = SyncMultiGPUParameterServerBuilder.__doc__
-    def __init__(self, towers, ps_device='gpu'):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, ps_device='gpu'):
         """
         Args:
-            towers ([int]): list of GPU ids.
+            gpus ([int]): list of GPU ids.
             ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
         """
-        self._builder = SyncMultiGPUParameterServerBuilder(towers, ps_device)
+        self._builder = SyncMultiGPUParameterServerBuilder(gpus, ps_device)
         super(SyncMultiGPUTrainerParameterServer, self).__init__()
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
@@ -69,28 +77,29 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
         return []
-def SyncMultiGPUTrainer(towers):
+def SyncMultiGPUTrainer(gpus):
     """
     Return a default multi-GPU trainer, if you don't care about the details.
     It may not be the most efficient one for your task.
     Args:
-        towers (list[int]): list of GPU ids.
+        gpus (list[int]): list of GPU ids.
     """
-    return SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu')
+    return SyncMultiGPUTrainerParameterServer(gpus, ps_device='gpu')
 class AsyncMultiGPUTrainer(SingleCostTrainer):
     __doc__ = AsyncMultiGPUBuilder.__doc__
-    def __init__(self, towers, scale_gradient=True):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, scale_gradient=True):
         """
         Args:
-            towers ([int]): list of GPU ids.
+            gpus ([int]): list of GPU ids.
             scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
         """
-        self._builder = AsyncMultiGPUBuilder(towers, scale_gradient)
+        self._builder = AsyncMultiGPUBuilder(gpus, scale_gradient)
         super(AsyncMultiGPUTrainer, self).__init__()
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
@@ -103,12 +112,13 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
     __doc__ = SyncMultiGPUReplicatedBuilder.__doc__
-    def __init__(self, towers):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus):
         """
         Args:
-            towers ([int]): list of GPU ids.
+            gpus ([int]): list of GPU ids.
         """
-        self._builder = SyncMultiGPUReplicatedBuilder(towers)
+        self._builder = SyncMultiGPUReplicatedBuilder(gpus)
         super(SyncMultiGPUTrainerReplicated, self).__init__()
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
@@ -125,10 +135,11 @@ class DistributedTrainerReplicated(SingleCostTrainer):
     __doc__ = DistributedReplicatedBuilder.__doc__
-    def __init__(self, towers, server):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, server):
         """
         Args:
-            towers (list[int]): list of GPU ids.
+            gpus (list[int]): list of GPU ids.
             server (tf.train.Server): the server with ps and workers.
                 The job_name must be 'worker' because 'ps' job doesn't need to
                 build any graph.
@@ -139,7 +150,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
         if self.job_name == 'worker':
             # ps doesn't build any graph
-            self._builder = DistributedReplicatedBuilder(towers, server)
+            self._builder = DistributedReplicatedBuilder(gpus, server)
             self.is_chief = self._builder.is_chief
         else:
             self.is_chief = False
...
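
The map_arg decorator itself lives in tensorpack/utils/argtools.py and is not part of this diff. For readers following along, below is a minimal sketch of a decorator with the behavior the new code relies on; it is a hypothetical stand-in written for illustration (the names map_arg_sketch and Demo are invented here), not tensorpack's actual implementation:

    import functools
    import inspect


    def map_arg_sketch(**maps):
        # Decorator factory: pass selected arguments of the wrapped function
        # through a converter before the call (stand-in for tensorpack's map_arg).
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Bind positional/keyword arguments to their parameter names,
                # convert the requested ones, then call the original function.
                bound = inspect.signature(func).bind(*args, **kwargs)
                for name, converter in maps.items():
                    if name in bound.arguments:
                        bound.arguments[name] = converter(bound.arguments[name])
                return func(*bound.args, **bound.kwargs)
            return wrapper
        return decorator


    def _int_to_range(x):
        # Same normalization this commit adds: an int n becomes [0, ..., n-1];
        # anything else (e.g. an explicit list of GPU ids) passes through unchanged.
        if isinstance(x, int):
            assert x > 0, x
            return list(range(x))
        return x


    class Demo(object):
        @map_arg_sketch(gpus=_int_to_range)
        def __init__(self, gpus):
            self.gpus = gpus


    print(Demo(3).gpus)       # [0, 1, 2]
    print(Demo([0, 2]).gpus)  # [0, 2]

This keeps the trainers' signatures honest: each __init__ still receives a list of GPU ids, while the int-to-list convenience lives in one reusable place instead of being repeated at every call site.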