Commit 875f4d7d authored by Yuxin Wu's avatar Yuxin Wu

Let get_predictor take GPU id, and remove get_predictors

parent cbb26847
......@@ -360,6 +360,9 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'remap_get_variable',
'freeze_get_variable',
'Triggerable',
'predictor_factory',
'get_predictors',
'vs_name_for_predictor',
'dump_chkpt_vars',
'ParamRestore']:
return True
......
......@@ -28,7 +28,8 @@ You can override any of the following methods to define a new callback:
Setup the ops / tensors in the graph which you might need to use in the callback. You can use
[`graph.get_tensor_by_name`](https://www.tensorflow.org/api_docs/python/tf/Graph#get_tensor_by_name)
to access those already defined in the training tower. Or use
to access those already defined in the training tower.
Or use
[`self.trainer.get_predictor(..)`](http://tensorpack.readthedocs.io/en/latest/modules/train.html?highlight=get_predictor#tensorpack.train.Trainer.get_predictor)
to create a callable evaluation function in the predict tower.
......
......@@ -144,15 +144,21 @@ class Model(ModelDesc):
class MySimulatorMaster(SimulatorMaster, Callback):
def __init__(self, pipe_c2s, pipe_s2c, model):
def __init__(self, pipe_c2s, pipe_s2c, model, gpus):
super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
self.M = model
self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2)
self._gpus = gpus
def _setup_graph(self):
# create predictors on the available predictor GPUs.
nr_gpu = len(self._gpus)
predictors = [self.trainer.get_predictor(
['state'], ['policy', 'pred_value'],
self._gpus[k % nr_gpu])
for k in range(PREDICTOR_THREAD)]
self.async_predictor = MultiThreadAsyncPredictor(
self.trainer.get_predictors(['state'], ['policy', 'pred_value'],
PREDICTOR_THREAD), batch_size=PREDICT_BATCH_SIZE)
predictors, batch_size=PREDICT_BATCH_SIZE)
def _before_train(self):
self.async_predictor.start()
......@@ -201,8 +207,23 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def get_config():
M = Model()
nr_gpu = get_nr_gpu()
if nr_gpu > 0:
if nr_gpu > 1:
# use half gpus for inference
predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
else:
predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
logger.info("[Batch-A3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
else:
logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
PREDICTOR_THREAD = 1
predict_tower, train_tower = [0], [0]
# setup simulator processes
name_base = str(uuid.uuid1())[:6]
PIPE_DIR = os.environ.get('TENSORPACK_PIPEDIR', '.').rstrip('/')
namec2s = 'ipc://{}/sim-c2s-{}'.format(PIPE_DIR, name_base)
......@@ -211,7 +232,8 @@ def get_config():
ensure_proc_terminate(procs)
start_proc_mask_signal(procs)
master = MySimulatorMaster(namec2s, names2c, M)
M = Model()
master = MySimulatorMaster(namec2s, names2c, M, predict_tower)
dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)
return TrainConfig(
model=M,
......@@ -232,6 +254,7 @@ def get_config():
config=get_default_sess_config(0.5)),
steps_per_epoch=STEPS_PER_EPOCH,
max_epoch=1000,
tower=train_tower
)
......@@ -274,27 +297,8 @@ if __name__ == '__main__':
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
nr_gpu = get_nr_gpu()
trainer = QueueInputTrainer
if nr_gpu > 0:
if nr_gpu > 1:
predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
else:
predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
logger.info("[Batch-A3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
if len(train_tower) > 1:
trainer = AsyncMultiGPUTrainer
else:
logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
PREDICTOR_THREAD = 1
predict_tower, train_tower = [0], [0]
trainer = QueueInputTrainer
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
config.tower = train_tower
config.predict_tower = predict_tower
trainer = QueueInputTrainer if config.nr_tower == 1 else AsyncMultiGPUTrainer
trainer(config).train()
......@@ -41,16 +41,13 @@ class PredictorTowerHandle(object):
class PredictorFactory(object):
""" Make predictors from :class:`ModelDesc`."""
def __init__(self, model, towers, vs_name=''):
def __init__(self, model, vs_name=''):
"""
Args:
model (ModelDesc):
towers (list[int]): list of available gpu id
vs_name (str):
"""
assert isinstance(towers, list), towers
self._model = model
self._towers = towers
self._vs_name = vs_name
self._names_built = {}
......@@ -82,12 +79,11 @@ class PredictorFactory(object):
def get_predictor(self, input_names, output_names, tower):
"""
Args:
tower (int): need the kth tower (not the gpu id, but the id in TrainConfig.predict_tower)
tower (int): use device '/gpu:{tower}' or use -1 for '/cpu:0'.
Returns:
an online predictor (which has to be used under a default session)
"""
tower_name = 'towerp{}'.format(tower)
tower = self._towers[tower]
device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
# use a previously-built tower
# TODO check conflict with inference runner??
......
......@@ -11,6 +11,7 @@ import tensorflow as tf
from ..graph_builder.predictor_factory import PredictorFactory
from .config import TrainConfig
from ..utils import logger
from ..utils.develop import deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value
......@@ -234,27 +235,22 @@ class Trainer(object):
"""
Args:
input_names (list), output_names(list): list of names
tower (int): return the predictor on the kth tower, defined by ``config.predict_tower``.
tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'.
Returns:
an :class:`OnlinePredictor`.
"""
# TODO move the logic to factory?
nr_tower = len(self.config.predict_tower)
if nr_tower < tower:
logger.warn(
"Requested the {}th predictor but only have {} predict towers! "
"Predictors will be assigned to GPUs in round-robin.".format(tower, nr_tower))
tower = tower % nr_tower
return self.predictor_factory.get_predictor(input_names, output_names, tower)
@property
def predictor_factory(self):
if not hasattr(self, '_predictor_factory'):
self._predictor_factory = PredictorFactory(
self.model, self.config.predict_tower, self.vs_name_for_predictor)
self.model, self.vs_name_for_predictor)
return self._predictor_factory
@deprecated("Please call `Trainer.get_predictor` to create them manually.")
def get_predictors(self, input_names, output_names, n):
""" Return n predictors. """
return [self.get_predictor(input_names, output_names, k) for k in range(n)]
......@@ -43,11 +43,12 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), RunUpdateOps()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
is only used to provide the defaults in addition to ``callbacks``. The defaults are
``MovingAverageSummary()``, ``ProgressBar()``,
``MergeAllSummaries()``, ``RunUpdateOps()``. The list of
callbacks that will be used in the end is ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFEventWriter(), JSONWriter(), ScalarPrinter()]``.
Defaults to ``TFEventWriter()``, ``JSONWriter()``, ``ScalarPrinter()``.
session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by :func:`tfutils.get_default_sess_config()`.
......@@ -59,8 +60,8 @@ class TrainConfig(object):
Defaults to the input data size.
max_epoch (int): maximum number of epoch to run training.
nr_tower (int): number of training towers.
tower (list of int): list of training towers in relative id.
nr_tower (int): number of training towers, used by multigpu trainers.
tower (list of int): list of training towers in relative GPU id.
predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu.
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment