Commit d9817a56 authored by Yuxin Wu's avatar Yuxin Wu

deprecate TrainConfig(predict_tower). The device should be passed directly to InferenceRunner

parent da98e447
......@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2017/10/18]
`TrainConfig(predict_tower)` was deprecated. You can set the inference device directly when creating the `InferenceRunner` callback.
+ [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e).
`tensorpack.RL` was deprecated. The RL examples are written with OpenAI gym interface instead.
+ [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc).
......
......@@ -101,7 +101,7 @@ class InferenceRunner(InferenceRunnerBase):
A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
"""
def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None):
def __init__(self, input, infs, tower_name='InferenceTower', device=0, extra_hooks=None):
"""
Args:
input (InputSource or DataFlow): The :class:`InputSource` to run
......@@ -109,11 +109,13 @@ class InferenceRunner(InferenceRunnerBase):
infs (list): a list of :class:`Inferencer` instances.
tower_name (str): the name scope of the tower to build. Need to set a
different one if multiple InferenceRunner are used.
gpu (int): the device to use
"""
if isinstance(input, DataFlow):
input = FeedInput(input, infinite=False)
assert isinstance(input, InputSource), input
self._tower_name = tower_name
self._device = device
super(InferenceRunner, self).__init__(
input, infs, extra_hooks=extra_hooks)
......@@ -127,8 +129,11 @@ class InferenceRunner(InferenceRunnerBase):
# old Trainer API
assert self.trainer.model is not None
# Use predict_tower in train config. either gpuid or -1
tower_id = self.trainer._config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
if self.trainer._config.predict_tower is not None:
device = self.trainer._config.predict_tower[0]
else:
device = self._device
device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
......
......@@ -13,6 +13,7 @@ from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource
from ..utils.develop import log_deprecated
__all__ = ['TrainConfig']
......@@ -128,8 +129,9 @@ class TrainConfig(object):
self.tower = tower
predict_tower = kwargs.pop('predict_tower', None)
if predict_tower is None:
predict_tower = [0]
if predict_tower is not None:
log_deprecated("TrainConfig(predict_tower=)",
"InferenceRunner now accepts a 'device' argument.", "2017-12-31")
self.predict_tower = predict_tower
if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment