Commit be3a07a1 authored by Yuxin Wu

config.tower shouldn't be used for v2. Towers are in trainers already.

parent 3d30826c
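In short: with the v2 interface, the list of training towers (GPU ids) belongs to the trainer object, and TrainConfig no longer needs `tower`/`nr_tower`. A minimal sketch of the resulting launch pattern, assuming tensorpack is installed and exports these names at the top level; `MyModel` and `my_dataflow` are placeholders:

    from tensorpack import (TrainConfig, SimpleTrainer,
                            AsyncMultiGPUTrainer, launch_train_with_config)

    config = TrainConfig(model=MyModel(), dataflow=my_dataflow, max_epoch=1000)
    # Towers are now an argument of the trainer, not of TrainConfig:
    train_tower = [0, 1]  # hypothetical GPU ids
    trainer = SimpleTrainer() if len(train_tower) == 1 else AsyncMultiGPUTrainer(train_tower)
    launch_train_with_config(config, trainer)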
@@ -139,9 +139,8 @@ class Model(ModelDesc):
 class MySimulatorMaster(SimulatorMaster, Callback):
-    def __init__(self, pipe_c2s, pipe_s2c, model, gpus):
+    def __init__(self, pipe_c2s, pipe_s2c, gpus):
         super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
-        self.M = model
         self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2)
         self._gpus = gpus
@@ -211,7 +210,11 @@ class MySimulatorMaster(SimulatorMaster, Callback):
         client.memory = []

-def get_config():
+def train():
+    dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
+    logger.set_logger_dir(dirname)
+
+    # assign GPUs for training & inference
     nr_gpu = get_nr_gpu()
     global PREDICTOR_THREAD
     if nr_gpu > 0:
@@ -238,11 +241,10 @@ def get_config():
     ensure_proc_terminate(procs)
     start_proc_mask_signal(procs)

-    M = Model()
-    master = MySimulatorMaster(namec2s, names2c, M, predict_tower)
+    master = MySimulatorMaster(namec2s, names2c, predict_tower)
     dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)
-    return TrainConfig(
-        model=M,
+    config = TrainConfig(
+        model=Model(),
         dataflow=dataflow,
         callbacks=[
             ModelSaver(),
@@ -259,9 +261,11 @@ def get_config():
         session_creator=sesscreate.NewSessionCreator(
             config=get_default_sess_config(0.5)),
         steps_per_epoch=STEPS_PER_EPOCH,
+        session_init=get_model_loader(args.load) if args.load else None,
         max_epoch=1000,
-        tower=train_tower
     )
+    trainer = SimpleTrainer() if config.nr_tower == 1 else AsyncMultiGPUTrainer(train_tower)
+    launch_train_with_config(config, trainer)
@@ -301,11 +305,4 @@ if __name__ == '__main__':
             pred, args.episode)
         # gym.upload(args.output, api_key='xxx')
     else:
-        dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
-        logger.set_logger_dir(dirname)
-        config = get_config()
-        if args.load:
-            config.session_init = get_model_loader(args.load)
-        trainer = SimpleTrainer() if config.nr_tower == 1 else AsyncMultiGPUTrainer(config.tower)
-        launch_train_with_config(config, trainer)
+        train()
@@ -84,9 +84,6 @@ class TrainConfig(object):
         steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
             Defaults to the input data size.
         max_epoch (int): maximum number of epoch to run training.
-        nr_tower (int): number of training towers, used by multigpu trainers.
-        tower ([int]): list of training towers in relative GPU id.
     """
     # TODO type checker decorator
@@ -147,6 +144,7 @@ class TrainConfig(object):
         self.max_epoch = int(max_epoch)
         assert self.steps_per_epoch > 0 and self.max_epoch > 0

+        # Tower stuff is for Trainer v1 only:
         nr_tower = max(nr_tower, 1)
         self.nr_tower = nr_tower
         if tower is not None:
@@ -160,6 +158,7 @@ class TrainConfig(object):
         self.predict_tower = predict_tower
         if isinstance(self.predict_tower, int):
             self.predict_tower = [self.predict_tower]
+        # --------------------------------------------------------------

         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
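For back-compatibility the constructor still parses `nr_tower`/`tower` (the block fenced off by the comments above), so v1-era configs keep working; a v2 trainer simply ignores them, apart from the warning added in `launch_train_with_config` below. A hedged example with hypothetical values:

    # Still accepted, and still sets config.nr_tower = 2 for Trainer v1.
    # A v2 trainer ignores it; launch_train_with_config only warns.
    config = TrainConfig(model=MyModel(), dataflow=my_dataflow, tower=[0, 1])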
...
@@ -6,6 +6,7 @@ import tensorflow as tf
 from ..input_source import (
     InputSource, FeedInput, QueueInput, StagingInput, DummyConstantInput)
+from ..utils import logger
 from .config import TrainConfig
 from .tower import SingleCostTrainer
@@ -14,14 +15,13 @@ from .trainers import SimpleTrainer
 __all__ = ['launch_train_with_config', 'apply_default_prefetch']

-def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
+def apply_default_prefetch(input_source_or_dataflow, trainer):
     """
     Apply a set of default rules to make a fast :class:`InputSource`.

     Args:
         input_source_or_dataflow(InputSource | DataFlow):
         trainer (Trainer):
-        towers ([int]): list of GPU ids.
     """
     if not isinstance(input_source_or_dataflow, InputSource):
         # to mimic same behavior of the old trainer interface
@@ -31,13 +31,15 @@ def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
             input = QueueInput(input_source_or_dataflow)
     else:
         input = input_source_or_dataflow
-    if len(towers) > 1:
-        # seem to only improve on >1 GPUs
-        assert not isinstance(trainer, SimpleTrainer)
-        assert tf.test.is_gpu_available()
-
-        if not isinstance(input, (StagingInput, DummyConstantInput)):
-            input = StagingInput(input, towers)
+    if hasattr(trainer, 'devices'):
+        towers = trainer.devices
+        if len(towers) > 1:
+            # seem to only improve on >1 GPUs
+            assert not isinstance(trainer, SimpleTrainer)
+            assert tf.test.is_gpu_available()
+
+            if not isinstance(input, (StagingInput, DummyConstantInput)):
+                input = StagingInput(input, towers)
     return input
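The `devices` lookup is plain duck typing: any trainer that exposes a `devices` list of GPU ids opts in to automatic staging, and single-tower trainers without the attribute skip it. A minimal sketch of the call, assuming both names are exported at the package top level and `my_dataflow` is a placeholder DataFlow:

    from tensorpack import apply_default_prefetch, SyncMultiGPUTrainerParameterServer

    trainer = SyncMultiGPUTrainerParameterServer([0, 1])
    # my_dataflow is wrapped in QueueInput, then in StagingInput,
    # because trainer.devices has more than one entry:
    input = apply_default_prefetch(my_dataflow, trainer)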
@@ -75,7 +77,10 @@ def launch_train_with_config(config, trainer):
     model = config.model
     inputs_desc = model.get_inputs_desc()
     input = config.data or config.dataflow
-    input = apply_default_prefetch(input, trainer, config.tower)
+    input = apply_default_prefetch(input, trainer)
+    if config.nr_tower > 1:
+        logger.warn("With trainer v2, setting tower in TrainConfig has no effect.")
+        logger.warn("It's enough to set the tower when initializing the trainer.")

     trainer.setup_graph(
         inputs_desc, input,
...
@@ -68,6 +68,8 @@ class TowerTrainer(Trainer):
         Returns:
             a :class:`TowerTensorHandles` object, to
             access the tower handles by either indices or names.
+
+            It is accessible only after the graph is set up.
         """
         return self.tower_func.towers
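Since the handles only exist once the graph does, access them after `setup_graph` (or from a callback). A sketch, assuming `TowerTensorHandles` supports index access and `get_tensor` lookup by name:

    trainer = SyncMultiGPUTrainerParameterServer([0, 1])
    trainer.setup_graph(inputs_desc, input, get_cost_fn, get_opt_fn)  # builds the towers
    handle = trainer.towers[0]        # handle for the first training tower
    cost = handle.get_tensor('cost')  # hypothetical tensor name inside that tower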
...
@@ -54,7 +54,7 @@ class SimpleTrainer(SingleCostTrainer):
         return []

-# Only works for type check
+# Only exists for type check & back-compatibility
 class QueueInputTrainer(SimpleTrainer):
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         assert isinstance(input, QueueInput)
@@ -65,6 +65,11 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
     __doc__ = SyncMultiGPUParameterServerBuilder.__doc__

+    devices = None
+    """
+    List of GPU ids.
+    """
+
     @map_arg(gpus=_int_to_range)
     def __init__(self, gpus, ps_device='gpu'):
         """
@@ -72,6 +77,7 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
             gpus ([int]): list of GPU ids.
             ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
         """
+        self.devices = gpus
         self._builder = SyncMultiGPUParameterServerBuilder(gpus, ps_device)
         super(SyncMultiGPUTrainerParameterServer, self).__init__()
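The effect of the new attribute, sketched with hypothetical GPU ids (and assuming `_int_to_range` expands an int `n` into `list(range(n))`, as `map_arg` suggests):

    trainer = SyncMultiGPUTrainerParameterServer([0, 1])
    print(trainer.devices)  # [0, 1]
    trainer = AsyncMultiGPUTrainer(2)  # an int is expanded by @map_arg(gpus=_int_to_range)
    print(trainer.devices)  # [0, 1]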
@@ -96,6 +102,11 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
     __doc__ = AsyncMultiGPUBuilder.__doc__

+    devices = None
+    """
+    List of GPU ids.
+    """
+
     @map_arg(gpus=_int_to_range)
     def __init__(self, gpus, scale_gradient=True):
         """
@@ -103,6 +114,7 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
             gpus ([int]): list of GPU ids.
             scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
         """
+        self.devices = gpus
         self._builder = AsyncMultiGPUBuilder(gpus, scale_gradient)
         super(AsyncMultiGPUTrainer, self).__init__()
@@ -116,12 +128,18 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
     __doc__ = SyncMultiGPUReplicatedBuilder.__doc__

+    devices = None
+    """
+    List of GPU ids.
+    """
+
     @map_arg(gpus=_int_to_range)
     def __init__(self, gpus):
         """
         Args:
             gpus ([int]): list of GPU ids.
         """
+        self.devices = gpus
         self._builder = SyncMultiGPUReplicatedBuilder(gpus)
         super(SyncMultiGPUTrainerReplicated, self).__init__()
@@ -139,6 +157,11 @@ class DistributedTrainerReplicated(SingleCostTrainer):
     __doc__ = DistributedReplicatedBuilder.__doc__

+    devices = None
+    """
+    List of GPU ids.
+    """
+
     @map_arg(gpus=_int_to_range)
     def __init__(self, gpus, server):
         """
@@ -146,6 +169,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             gpus (list[int]): list of GPU ids.
             server (tf.train.Server): the server with ps and workers.
         """
+        self.devices = gpus
         self.server = server
         self.job_name = server.server_def.job_name
         assert self.job_name in ['ps', 'worker'], self.job_name
...