Commit be3a07a1 authored by Yuxin Wu

config.tower shouldn't be used for v2. towers are in trainers already.

parent 3d30826c
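
In short: with the trainer v2 interface, the GPU/tower assignment belongs to the trainer object, while `TrainConfig` only carries the model, data, callbacks and schedule. A minimal sketch of the pattern the example below is moved to, assuming tensorpack's top-level exports; `Model`, `dataflow`, `train_tower` and `STEPS_PER_EPOCH` are placeholders taken from the example script:

```python
from tensorpack import (TrainConfig, SimpleTrainer, AsyncMultiGPUTrainer,
                        ModelSaver, launch_train_with_config)

config = TrainConfig(
    model=Model(),                    # no `tower=` / `nr_tower=` arguments needed anymore
    dataflow=dataflow,
    callbacks=[ModelSaver()],
    steps_per_epoch=STEPS_PER_EPOCH,
    max_epoch=1000,
)
# The GPU list goes to the trainer, not to TrainConfig:
trainer = SimpleTrainer() if len(train_tower) == 1 else AsyncMultiGPUTrainer(train_tower)
launch_train_with_config(config, trainer)
```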
......@@ -139,9 +139,8 @@ class Model(ModelDesc):
class MySimulatorMaster(SimulatorMaster, Callback):
def __init__(self, pipe_c2s, pipe_s2c, model, gpus):
def __init__(self, pipe_c2s, pipe_s2c, gpus):
super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
self.M = model
self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2)
self._gpus = gpus
......@@ -211,7 +210,11 @@ class MySimulatorMaster(SimulatorMaster, Callback):
client.memory = []
def get_config():
def train():
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
# assign GPUs for training & inference
nr_gpu = get_nr_gpu()
global PREDICTOR_THREAD
if nr_gpu > 0:
......@@ -238,11 +241,10 @@ def get_config():
ensure_proc_terminate(procs)
start_proc_mask_signal(procs)
M = Model()
master = MySimulatorMaster(namec2s, names2c, M, predict_tower)
master = MySimulatorMaster(namec2s, names2c, predict_tower)
dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)
return TrainConfig(
model=M,
config = TrainConfig(
model=Model(),
dataflow=dataflow,
callbacks=[
ModelSaver(),
......@@ -259,9 +261,11 @@ def get_config():
session_creator=sesscreate.NewSessionCreator(
config=get_default_sess_config(0.5)),
steps_per_epoch=STEPS_PER_EPOCH,
session_init=get_model_loader(args.load) if args.load else None,
max_epoch=1000,
tower=train_tower
)
trainer = SimpleTrainer() if config.nr_tower == 1 else AsyncMultiGPUTrainer(train_tower)
launch_train_with_config(config, trainer)
if __name__ == '__main__':
......@@ -301,11 +305,4 @@ if __name__ == '__main__':
pred, args.episode)
# gym.upload(args.output, api_key='xxx')
else:
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
trainer = SimpleTrainer() if config.nr_tower == 1 else AsyncMultiGPUTrainer(config.tower)
launch_train_with_config(config, trainer)
train()
......@@ -84,9 +84,6 @@ class TrainConfig(object):
steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
Defaults to the input data size.
max_epoch (int): maximum number of epoch to run training.
nr_tower (int): number of training towers, used by multigpu trainers.
tower ([int]): list of training towers in relative GPU id.
"""
# TODO type checker decorator
......@@ -147,6 +144,7 @@ class TrainConfig(object):
self.max_epoch = int(max_epoch)
assert self.steps_per_epoch > 0 and self.max_epoch > 0
# Tower stuff are for Trainer v1 only:
nr_tower = max(nr_tower, 1)
self.nr_tower = nr_tower
if tower is not None:
......@@ -160,6 +158,7 @@ class TrainConfig(object):
self.predict_tower = predict_tower
if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower]
# --------------------------------------------------------------
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
......
......@@ -6,6 +6,7 @@ import tensorflow as tf
from ..input_source import (
InputSource, FeedInput, QueueInput, StagingInput, DummyConstantInput)
from ..utils import logger
from .config import TrainConfig
from .tower import SingleCostTrainer
......@@ -14,14 +15,13 @@ from .trainers import SimpleTrainer
__all__ = ['launch_train_with_config', 'apply_default_prefetch']
def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
def apply_default_prefetch(input_source_or_dataflow, trainer):
"""
Apply a set of default rules to make a fast :class:`InputSource`.
Args:
input_source_or_dataflow(InputSource | DataFlow):
trainer (Trainer):
towers ([int]): list of GPU ids.
"""
if not isinstance(input_source_or_dataflow, InputSource):
# to mimic same behavior of the old trainer interface
......@@ -31,13 +31,15 @@ def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
input = QueueInput(input_source_or_dataflow)
else:
input = input_source_or_dataflow
if len(towers) > 1:
# seem to only improve on >1 GPUs
assert not isinstance(trainer, SimpleTrainer)
assert tf.test.is_gpu_available()
if not isinstance(input, (StagingInput, DummyConstantInput)):
input = StagingInput(input, towers)
if hasattr(trainer, 'devices'):
towers = trainer.devices
if len(towers) > 1:
# seem to only improve on >1 GPUs
assert not isinstance(trainer, SimpleTrainer)
assert tf.test.is_gpu_available()
if not isinstance(input, (StagingInput, DummyConstantInput)):
input = StagingInput(input, towers)
return input
......@@ -75,7 +77,10 @@ def launch_train_with_config(config, trainer):
model = config.model
inputs_desc = model.get_inputs_desc()
input = config.data or config.dataflow
input = apply_default_prefetch(input, trainer, config.tower)
input = apply_default_prefetch(input, trainer)
if config.nr_tower > 1:
logger.warn("With trainer v2, setting tower in TrainConfig has no effect.")
logger.warn("It's enough to set the tower when initializing the trainer.")
trainer.setup_graph(
inputs_desc, input,
......
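Net effect of the interface.py change above: `apply_default_prefetch` no longer takes a tower list from the caller; when the trainer exposes a `devices` attribute listing more than one GPU, the input is wrapped in `StagingInput` automatically. A hedged usage sketch (`my_dataflow` is a placeholder DataFlow, and the multi-GPU branch needs visible GPUs at runtime because of the `tf.test.is_gpu_available()` assertion):

```python
from tensorpack import SimpleTrainer, SyncMultiGPUTrainerParameterServer
from tensorpack.train.interface import apply_default_prefetch

trainer = SyncMultiGPUTrainerParameterServer([0, 1])    # trainer.devices == [0, 1]
inp = apply_default_prefetch(my_dataflow, trainer)      # QueueInput wrapped in StagingInput

# SimpleTrainer defines no `devices`, so the dataflow only gets a QueueInput:
inp_single = apply_default_prefetch(my_dataflow, SimpleTrainer())
```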
......@@ -68,6 +68,8 @@ class TowerTrainer(Trainer):
Returns:
a :class:`TowerTensorHandles` object, to
access the tower handles by either indices or names.
It is accessible only after the graph is set up.
"""
return self.tower_func.towers
......
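A hedged sketch of when the tower handles documented above become available (`inputs_desc`, `inp`, `get_cost_fn` and `get_opt_fn` stand in for the arguments that `launch_train_with_config` normally supplies):

```python
trainer = SyncMultiGPUTrainerReplicated([0, 1])
# Before setup_graph() the towers have not been built, so the handles are not usable yet.
trainer.setup_graph(inputs_desc, inp, get_cost_fn, get_opt_fn)
handles = trainer.towers      # a TowerTensorHandles
first = handles[0]            # index by position ...
# ... or look a tower up by its name; the exact names depend on the trainer/builder.
```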
......@@ -54,7 +54,7 @@ class SimpleTrainer(SingleCostTrainer):
return []
# Only works for type check
# Only exists for type check & back-compatibility
class QueueInputTrainer(SimpleTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
assert isinstance(input, QueueInput)
......@@ -65,6 +65,11 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
__doc__ = SyncMultiGPUParameterServerBuilder.__doc__
devices = None
"""
List of GPU ids.
"""
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, ps_device='gpu'):
"""
......@@ -72,6 +77,7 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
gpus ([int]): list of GPU ids.
ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
"""
self.devices = gpus
self._builder = SyncMultiGPUParameterServerBuilder(gpus, ps_device)
super(SyncMultiGPUTrainerParameterServer, self).__init__()
......@@ -96,6 +102,11 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
__doc__ = AsyncMultiGPUBuilder.__doc__
devices = None
"""
List of GPU ids.
"""
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, scale_gradient=True):
"""
......@@ -103,6 +114,7 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
gpus ([int]): list of GPU ids.
scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
"""
self.devices = gpus
self._builder = AsyncMultiGPUBuilder(gpus, scale_gradient)
super(AsyncMultiGPUTrainer, self).__init__()
......@@ -116,12 +128,18 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
__doc__ = SyncMultiGPUReplicatedBuilder.__doc__
devices = None
"""
List of GPU ids.
"""
@map_arg(gpus=_int_to_range)
def __init__(self, gpus):
"""
Args:
gpus ([int]): list of GPU ids.
"""
self.devices = gpus
self._builder = SyncMultiGPUReplicatedBuilder(gpus)
super(SyncMultiGPUTrainerReplicated, self).__init__()
......@@ -139,6 +157,11 @@ class DistributedTrainerReplicated(SingleCostTrainer):
__doc__ = DistributedReplicatedBuilder.__doc__
devices = None
"""
List of GPU ids.
"""
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, server):
"""
......@@ -146,6 +169,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
gpus (list[int]): list of GPU ids.
server (tf.train.Server): the server with ps and workers.
"""
self.devices = gpus
self.server = server
self.job_name = server.server_def.job_name
assert self.job_name in ['ps', 'worker'], self.job_name
......
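The repeated pattern above gives every multi-GPU trainer a `devices` attribute (class default `None`, set to the GPU list in `__init__`); this is the attribute the new `hasattr(trainer, 'devices')` check in `apply_default_prefetch` relies on. A small sketch of that contract (mirroring, not reusing, the interface.py logic):

```python
from tensorpack import SimpleTrainer, AsyncMultiGPUTrainer

def needs_staging(trainer):
    # Stage the input only when the trainer reports more than one device.
    devices = getattr(trainer, 'devices', None)
    return devices is not None and len(devices) > 1

needs_staging(AsyncMultiGPUTrainer([0, 1]))   # True  -- devices == [0, 1]
needs_staging(SimpleTrainer())                # False -- no `devices` attribute
```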