Commit 875f4d7d authored by Yuxin Wu

Let get_predictor take GPU id, and remove get_predictors

parent cbb26847
...@@ -360,6 +360,9 @@ def autodoc_skip_member(app, what, name, obj, skip, options): ...@@ -360,6 +360,9 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'remap_get_variable', 'remap_get_variable',
'freeze_get_variable', 'freeze_get_variable',
'Triggerable', 'Triggerable',
'predictor_factory',
'get_predictors',
'vs_name_for_predictor',
'dump_chkpt_vars', 'dump_chkpt_vars',
'ParamRestore']: 'ParamRestore']:
return True return True
......
...@@ -28,7 +28,8 @@ You can override any of the following methods to define a new callback: ...@@ -28,7 +28,8 @@ You can override any of the following methods to define a new callback:
Setup the ops / tensors in the graph which you might need to use in the callback. You can use Setup the ops / tensors in the graph which you might need to use in the callback. You can use
[`graph.get_tensor_by_name`](https://www.tensorflow.org/api_docs/python/tf/Graph#get_tensor_by_name) [`graph.get_tensor_by_name`](https://www.tensorflow.org/api_docs/python/tf/Graph#get_tensor_by_name)
to access those already defined in the training tower. Or use to access those already defined in the training tower.
Or use
[`self.trainer.get_predictor(..)`](http://tensorpack.readthedocs.io/en/latest/modules/train.html?highlight=get_predictor#tensorpack.train.Trainer.get_predictor) [`self.trainer.get_predictor(..)`](http://tensorpack.readthedocs.io/en/latest/modules/train.html?highlight=get_predictor#tensorpack.train.Trainer.get_predictor)
to create a callable evaluation function in the predict tower. to create a callable evaluation function in the predict tower.
......
...@@ -144,15 +144,21 @@ class Model(ModelDesc): ...@@ -144,15 +144,21 @@ class Model(ModelDesc):
class MySimulatorMaster(SimulatorMaster, Callback): class MySimulatorMaster(SimulatorMaster, Callback):
def __init__(self, pipe_c2s, pipe_s2c, model): def __init__(self, pipe_c2s, pipe_s2c, model, gpus):
super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c) super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
self.M = model self.M = model
self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2) self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2)
self._gpus = gpus
def _setup_graph(self): def _setup_graph(self):
# create predictors on the available predictor GPUs.
nr_gpu = len(self._gpus)
predictors = [self.trainer.get_predictor(
['state'], ['policy', 'pred_value'],
self._gpus[k % nr_gpu])
for k in range(PREDICTOR_THREAD)]
self.async_predictor = MultiThreadAsyncPredictor( self.async_predictor = MultiThreadAsyncPredictor(
self.trainer.get_predictors(['state'], ['policy', 'pred_value'], predictors, batch_size=PREDICT_BATCH_SIZE)
PREDICTOR_THREAD), batch_size=PREDICT_BATCH_SIZE)
def _before_train(self): def _before_train(self):
self.async_predictor.start() self.async_predictor.start()
...@@ -201,8 +207,23 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -201,8 +207,23 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def get_config(): def get_config():
M = Model() nr_gpu = get_nr_gpu()
if nr_gpu > 0:
if nr_gpu > 1:
# use half gpus for inference
predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
else:
predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
logger.info("[Batch-A3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
else:
logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
PREDICTOR_THREAD = 1
predict_tower, train_tower = [0], [0]
# setup simulator processes
name_base = str(uuid.uuid1())[:6] name_base = str(uuid.uuid1())[:6]
PIPE_DIR = os.environ.get('TENSORPACK_PIPEDIR', '.').rstrip('/') PIPE_DIR = os.environ.get('TENSORPACK_PIPEDIR', '.').rstrip('/')
namec2s = 'ipc://{}/sim-c2s-{}'.format(PIPE_DIR, name_base) namec2s = 'ipc://{}/sim-c2s-{}'.format(PIPE_DIR, name_base)
...@@ -211,7 +232,8 @@ def get_config(): ...@@ -211,7 +232,8 @@ def get_config():
ensure_proc_terminate(procs) ensure_proc_terminate(procs)
start_proc_mask_signal(procs) start_proc_mask_signal(procs)
master = MySimulatorMaster(namec2s, names2c, M) M = Model()
master = MySimulatorMaster(namec2s, names2c, M, predict_tower)
dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE) dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)
return TrainConfig( return TrainConfig(
model=M, model=M,
...@@ -232,6 +254,7 @@ def get_config(): ...@@ -232,6 +254,7 @@ def get_config():
config=get_default_sess_config(0.5)), config=get_default_sess_config(0.5)),
steps_per_epoch=STEPS_PER_EPOCH, steps_per_epoch=STEPS_PER_EPOCH,
max_epoch=1000, max_epoch=1000,
tower=train_tower
) )
...@@ -274,27 +297,8 @@ if __name__ == '__main__': ...@@ -274,27 +297,8 @@ if __name__ == '__main__':
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME)) dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname) logger.set_logger_dir(dirname)
nr_gpu = get_nr_gpu()
trainer = QueueInputTrainer
if nr_gpu > 0:
if nr_gpu > 1:
predict_tower = list(range(nr_gpu))[-nr_gpu // 2:]
else:
predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = list(range(nr_gpu))[:-nr_gpu // 2] or [0]
logger.info("[Batch-A3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
if len(train_tower) > 1:
trainer = AsyncMultiGPUTrainer
else:
logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
PREDICTOR_THREAD = 1
predict_tower, train_tower = [0], [0]
trainer = QueueInputTrainer
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = get_model_loader(args.load) config.session_init = get_model_loader(args.load)
config.tower = train_tower trainer = QueueInputTrainer if config.nr_tower == 1 else AsyncMultiGPUTrainer
config.predict_tower = predict_tower
trainer(config).train() trainer(config).train()
...@@ -41,16 +41,13 @@ class PredictorTowerHandle(object): ...@@ -41,16 +41,13 @@ class PredictorTowerHandle(object):
class PredictorFactory(object): class PredictorFactory(object):
""" Make predictors from :class:`ModelDesc`.""" """ Make predictors from :class:`ModelDesc`."""
def __init__(self, model, towers, vs_name=''): def __init__(self, model, vs_name=''):
""" """
Args: Args:
model (ModelDesc): model (ModelDesc):
towers (list[int]): list of available gpu id
vs_name (str): vs_name (str):
""" """
assert isinstance(towers, list), towers
self._model = model self._model = model
self._towers = towers
self._vs_name = vs_name self._vs_name = vs_name
self._names_built = {} self._names_built = {}
...@@ -82,12 +79,11 @@ class PredictorFactory(object): ...@@ -82,12 +79,11 @@ class PredictorFactory(object):
def get_predictor(self, input_names, output_names, tower): def get_predictor(self, input_names, output_names, tower):
""" """
Args: Args:
tower (int): need the kth tower (not the gpu id, but the id in TrainConfig.predict_tower) tower (int): use device '/gpu:{tower}' or use -1 for '/cpu:0'.
Returns: Returns:
an online predictor (which has to be used under a default session) an online predictor (which has to be used under a default session)
""" """
tower_name = 'towerp{}'.format(tower) tower_name = 'towerp{}'.format(tower)
tower = self._towers[tower]
device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0' device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
# use a previously-built tower # use a previously-built tower
# TODO check conflict with inference runner?? # TODO check conflict with inference runner??
......
...@@ -11,6 +11,7 @@ import tensorflow as tf ...@@ -11,6 +11,7 @@ import tensorflow as tf
from ..graph_builder.predictor_factory import PredictorFactory from ..graph_builder.predictor_factory import PredictorFactory
from .config import TrainConfig from .config import TrainConfig
from ..utils import logger from ..utils import logger
from ..utils.develop import deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..callbacks.monitor import Monitors, TrainingMonitor from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value from ..tfutils import get_global_step_value
...@@ -234,27 +235,22 @@ class Trainer(object): ...@@ -234,27 +235,22 @@ class Trainer(object):
""" """
Args: Args:
input_names (list), output_names(list): list of names input_names (list), output_names(list): list of names
tower (int): return the predictor on the kth tower, defined by ``config.predict_tower``. tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'.
Returns: Returns:
an :class:`OnlinePredictor`. an :class:`OnlinePredictor`.
""" """
# TODO move the logic to factory? # TODO move the logic to factory?
nr_tower = len(self.config.predict_tower)
if nr_tower < tower:
logger.warn(
"Requested the {}th predictor but only have {} predict towers! "
"Predictors will be assigned to GPUs in round-robin.".format(tower, nr_tower))
tower = tower % nr_tower
return self.predictor_factory.get_predictor(input_names, output_names, tower) return self.predictor_factory.get_predictor(input_names, output_names, tower)
@property @property
def predictor_factory(self): def predictor_factory(self):
if not hasattr(self, '_predictor_factory'): if not hasattr(self, '_predictor_factory'):
self._predictor_factory = PredictorFactory( self._predictor_factory = PredictorFactory(
self.model, self.config.predict_tower, self.vs_name_for_predictor) self.model, self.vs_name_for_predictor)
return self._predictor_factory return self._predictor_factory
@deprecated("Please call `Trainer.get_predictor` to create them manually.")
def get_predictors(self, input_names, output_names, n): def get_predictors(self, input_names, output_names, n):
""" Return n predictors. """ """ Return n predictors. """
return [self.get_predictor(input_names, output_names, k) for k in range(n)] return [self.get_predictor(input_names, output_names, k) for k in range(n)]
...@@ -43,11 +43,12 @@ class TrainConfig(object): ...@@ -43,11 +43,12 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training. callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are is only used to provide the defaults in addition to ``callbacks``. The defaults are
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), RunUpdateOps()]``. The list of ``MovingAverageSummary()``, ``ProgressBar()``,
callbacks that will be used in the end are ``callbacks + extra_callbacks``. ``MergeAllSummaries()``, ``RunUpdateOps()``. The list of
callbacks that will be used in the end is ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`. monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFEventWriter(), JSONWriter(), ScalarPrinter()]``. Defaults to ``TFEventWriter()``, ``JSONWriter()``, ``ScalarPrinter()``.
session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()` session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by :func:`tfutils.get_default_sess_config()`. with the config returned by :func:`tfutils.get_default_sess_config()`.
...@@ -59,8 +60,8 @@ class TrainConfig(object): ...@@ -59,8 +60,8 @@ class TrainConfig(object):
Defaults to the input data size. Defaults to the input data size.
max_epoch (int): maximum number of epoch to run training. max_epoch (int): maximum number of epoch to run training.
nr_tower (int): number of training towers. nr_tower (int): number of training towers, used by multigpu trainers.
tower (list of int): list of training towers in relative id. tower (list of int): list of training towers in relative GPU id.
predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu. predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu.
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment