Commit d9817a56 authored by Yuxin Wu's avatar Yuxin Wu

deprecate TrainConfig(predict_tower). The device should be passed directly to InferenceRunner

parent da98e447
...@@ -8,6 +8,8 @@ so you won't need to look at here very often. ...@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version. Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here. TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2017/10/18]
`TrainConfig(predict_tower)` was deprecated. You can set the inference device directly when creating the `InferenceRunner` callback.
+ [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e). + [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e).
`tensorpack.RL` was deprecated. The RL examples are written with OpenAI gym interface instead. `tensorpack.RL` was deprecated. The RL examples are written with OpenAI gym interface instead.
+ [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc). + [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc).
......
...@@ -101,7 +101,7 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -101,7 +101,7 @@ class InferenceRunner(InferenceRunnerBase):
A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`. A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
""" """
def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None): def __init__(self, input, infs, tower_name='InferenceTower', device=0, extra_hooks=None):
""" """
Args: Args:
input (InputSource or DataFlow): The :class:`InputSource` to run input (InputSource or DataFlow): The :class:`InputSource` to run
...@@ -109,11 +109,13 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -109,11 +109,13 @@ class InferenceRunner(InferenceRunnerBase):
infs (list): a list of :class:`Inferencer` instances. infs (list): a list of :class:`Inferencer` instances.
tower_name (str): the name scope of the tower to build. Need to set a tower_name (str): the name scope of the tower to build. Need to set a
different one if multiple InferenceRunner are used. different one if multiple InferenceRunner are used.
gpu (int): the device to use
""" """
if isinstance(input, DataFlow): if isinstance(input, DataFlow):
input = FeedInput(input, infinite=False) input = FeedInput(input, infinite=False)
assert isinstance(input, InputSource), input assert isinstance(input, InputSource), input
self._tower_name = tower_name self._tower_name = tower_name
self._device = device
super(InferenceRunner, self).__init__( super(InferenceRunner, self).__init__(
input, infs, extra_hooks=extra_hooks) input, infs, extra_hooks=extra_hooks)
...@@ -127,8 +129,11 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -127,8 +129,11 @@ class InferenceRunner(InferenceRunnerBase):
# old Trainer API # old Trainer API
assert self.trainer.model is not None assert self.trainer.model is not None
# Use predict_tower in train config. either gpuid or -1 # Use predict_tower in train config. either gpuid or -1
tower_id = self.trainer._config.predict_tower[0] if self.trainer._config.predict_tower is not None:
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0' device = self.trainer._config.predict_tower[0]
else:
device = self._device
device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc()) input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
......
...@@ -13,6 +13,7 @@ from ..tfutils import (JustCurrentSession, ...@@ -13,6 +13,7 @@ from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit) get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource from ..input_source import InputSource
from ..utils.develop import log_deprecated
__all__ = ['TrainConfig'] __all__ = ['TrainConfig']
...@@ -128,8 +129,9 @@ class TrainConfig(object): ...@@ -128,8 +129,9 @@ class TrainConfig(object):
self.tower = tower self.tower = tower
predict_tower = kwargs.pop('predict_tower', None) predict_tower = kwargs.pop('predict_tower', None)
if predict_tower is None: if predict_tower is not None:
predict_tower = [0] log_deprecated("TrainConfig(predict_tower=)",
"InferenceRunner now accepts a 'device' argument.", "2017-12-31")
self.predict_tower = predict_tower self.predict_tower = predict_tower
if isinstance(self.predict_tower, int): if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower] self.predict_tower = [self.predict_tower]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment