Commit 9dba9893 authored by Yuxin Wu

Let predictor share variable scope on each specific GPU

parent 30240eae
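In plain terms, this commit makes each predictor tower reuse the variables of the training tower that lives on the same GPU, instead of always reusing the main tower's copy. Below is a minimal sketch of the new scope-selection logic; `Tower` and `vs_name_for_predictor` are illustrative stand-ins, not tensorpack API.

```python
from collections import namedtuple

# Illustrative stand-in for a training tower handle; the real TowerTensorHandle
# carries much more state (tensors, ns_name, ...).
Tower = namedtuple("Tower", ["vs_name"])

def vs_name_for_predictor(training_towers, device):
    """Return the variable-scope name of the training tower on `device`,
    falling back to the first (main) tower if that device has no training copy."""
    available_ids = list(range(len(training_towers)))
    if device in available_ids:
        return training_towers[device].vs_name
    return training_towers[0].vs_name

# With replicated training on 4 GPUs the towers typically use scopes
# '', 'tower1', 'tower2', 'tower3'; a predictor on GPU 2 now shares 'tower2':
towers = [Tower(""), Tower("tower1"), Tower("tower2"), Tower("tower3")]
assert vs_name_for_predictor(towers, 2) == "tower2"
assert vs_name_for_predictor(towers, -1) == ""   # CPU predictor falls back to the main tower
```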
@@ -126,6 +126,7 @@ class InferenceRunner(InferenceRunnerBase):
         assert isinstance(input, InputSource), input
         assert not isinstance(input, StagingInput), input
         self._tower_name = tower_name
+        self._device_id = device
         self._device = _device_from_int(device)
         super(InferenceRunner, self).__init__(input, infs)
@@ -139,11 +140,13 @@ class InferenceRunner(InferenceRunnerBase):
         tower_func = self.trainer.tower_func
         input_callbacks = self._input_source.setup(tower_func.inputs_desc)
-        logger.info("[InferenceRunner] Building tower '{}' on device {} ...".format(self._tower_name, self._device))
+        vs_name = self.trainer._vs_name_for_predictor(self._device_id)
+        logger.info("[InferenceRunner] Building tower '{}' on device {} {}...".format(
+            self._tower_name, self._device,
+            "with variable scope '{}'".format(vs_name) if vs_name else ''))
         with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
                 tf.device(self._device), \
-                PredictTowerContext(
-                    self._tower_name, vs_name=self.trainer._main_tower_vs_name):
+                PredictTowerContext(self._tower_name, vs_name=vs_name):
             tower_func(*self._input_source.get_input_tensors())
             self._tower_handle = tower_func.towers[-1]
@@ -211,8 +214,13 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         input_callbacks = self._input_source.setup(tower_func.inputs_desc)
         with tf.variable_scope(tf.get_variable_scope(), reuse=True):
             for idx, dev in enumerate(self._devices):
+                vs_name = self.trainer._vs_name_for_predictor(idx)
                 with tf.device(dev), PredictTowerContext(
-                        self._tower_names[idx], vs_name=self.trainer._main_tower_vs_name):
+                        self._tower_names[idx], vs_name=vs_name):
+                    logger.info("[InferenceRunner] Building tower '{}' on device {} {}...".format(
+                        self._tower_names[idx], dev,
+                        "with variable scope '{}'".format(vs_name) if vs_name else ''))
+                    # TODO log for tower creation, here or in tower.py?
                     tower_func(*self._input_source.get_input_tensors())
                     self._handles.append(tower_func.towers[-1])
...
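For context, a rough illustration (not tensorpack code) of why the predict tower is built under `reuse=True` together with a training tower's variable-scope name: building the graph again under the same scope picks up the existing variables rather than creating new ones.

```python
import tensorflow as tf

def model():
    # Any graph-building function that creates variables via tf.get_variable.
    return tf.get_variable("w", shape=[3], initializer=tf.zeros_initializer())

# Training copy on one GPU, living under the scope 'tower2'.
with tf.variable_scope("tower2"):
    w_train = model()

# Predict copy built later: reuse=True plus the same vs_name means no new variables.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    with tf.variable_scope("tower2"):
        w_pred = model()

assert w_train is w_pred   # the predictor shares the training tower's variables
```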
@@ -9,8 +9,6 @@ import six
 from ..tfutils.common import get_tensors_by_names
 from ..tfutils.tower import PredictTowerContext
 from ..input_source import PlaceholderInput
-from ..utils.develop import log_deprecated
-from ..utils.utils import execute_only_once

 __all__ = ['PredictorBase', 'AsyncPredictorBase',
            'OnlinePredictor', 'OfflinePredictor',
@@ -27,7 +25,7 @@ class PredictorBase(object):
             or just outputs
     """

-    def __call__(self, *args):
+    def __call__(self, *dp):
         """
         Call the predictor on some inputs.
@@ -38,15 +36,6 @@ class PredictorBase(object):
                 predictor(e1, e2)
         """
-        if len(args) == 1 and isinstance(args[0], (list, tuple)):
-            dp = args[0]  # backward-compatibility
-            if execute_only_once():
-                log_deprecated(
-                    "Calling a predictor with one datapoint",
-                    "Call it with positional arguments instead!",
-                    "2018-3-1")
-        else:
-            dp = args
         output = self._do_call(dp)
         if self.return_input:
             return (dp, output)
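The removed branch was the deprecated `predictor([e1, e2])` form; only positional arguments remain. A small sketch of the resulting calling convention, using a dummy predictor rather than a real tensorpack one:

```python
class _DummyPredictor:
    """Toy stand-in for PredictorBase.__call__(self, *dp) after this change."""
    return_input = False

    def __call__(self, *dp):
        # dp is a tuple of datapoint components; a single list argument is no
        # longer unpacked for backward compatibility.
        return self._do_call(dp)

    def _do_call(self, dp):
        return [sum(component) for component in dp]   # placeholder for a session run

pred = _DummyPredictor()
print(pred([1, 2], [3, 4]))    # each positional argument is one input component -> [3, 7]
```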
@@ -94,6 +83,12 @@ class OnlinePredictor(PredictorBase):
     """

     ACCEPT_OPTIONS = False
+    """ See Session.make_callable """
+
+    sess = None
+    """
+    The tf.Session object associated with this predictor.
+    """

     def __init__(self, input_tensors, output_tensors,
                  return_input=False, sess=None):
@@ -104,6 +99,7 @@ class OnlinePredictor(PredictorBase):
             return_input (bool): same as :attr:`PredictorBase.return_input`.
             sess (tf.Session): the session this predictor runs in. If None,
                 will use the default session at the first call.
+                Note that in TensorFlow, default session is thread-local.
         """
         self.return_input = return_input
         self.input_tensors = input_tensors
@@ -123,6 +119,7 @@ class OnlinePredictor(PredictorBase):
             "{} != {}".format(len(dp), len(self.input_tensors))
         if self.sess is None:
             self.sess = tf.get_default_session()
+            assert self.sess is not None, "Predictor isn't called under a default session!"
         if self._callable is None:
             self._callable = self.sess.make_callable(
...
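A rough sketch (assuming a TF1-style API and a toy graph) of the behaviour the new `sess` attribute and assertion document: when no session is passed in, the predictor grabs the thread-local default session on its first call and caches a `Session.make_callable` handle.

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None], name="x")
y = x * 2

class _MiniPredictor:
    """Simplified stand-in for OnlinePredictor's session handling (not tensorpack code)."""
    def __init__(self, input_tensors, output_tensors, sess=None):
        self.input_tensors = input_tensors
        self.output_tensors = output_tensors
        self.sess = sess
        self._callable = None

    def __call__(self, *dp):
        if self.sess is None:
            # Default sessions are thread-local, hence the explicit check.
            self.sess = tf.get_default_session()
            assert self.sess is not None, "Predictor isn't called under a default session!"
        if self._callable is None:
            # make_callable builds a reusable feed/fetch plan, cheaper than repeated sess.run.
            self._callable = self.sess.make_callable(
                fetches=self.output_tensors, feed_list=self.input_tensors)
        return self._callable(*dp)

pred = _MiniPredictor([x], [y])
with tf.Session() as sess:          # entering the block makes `sess` the default session
    print(pred([1.0, 2.0]))         # doubles the input: [array([2., 4.], dtype=float32)]
```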
@@ -280,6 +280,9 @@ class TowerTensorHandles(object):
         self._handles = handles
         self._name_to_handle = {k.ns_name: k for k in handles}

+    def __len__(self):
+        return len(self._handles)
+
     def __getitem__(self, name_or_index):
         """
         Args:
...
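The new `__len__` lets callers (such as the `_vs_name_for_predictor` helper below, which does `list(range(len(towers)))`) treat the tower handles as a sized container. A minimal sketch of the resulting container behaviour, with plain dicts standing in for real tower handles:

```python
class _Handles:
    """Toy version of TowerTensorHandles' container protocol (not tensorpack code)."""
    def __init__(self, handles):
        self._handles = handles
        self._name_to_handle = {h["ns_name"]: h for h in handles}

    def __len__(self):
        return len(self._handles)

    def __getitem__(self, name_or_index):
        if isinstance(name_or_index, int):
            return self._handles[name_or_index]
        return self._name_to_handle[name_or_index]

hs = _Handles([{"ns_name": "tower0"}, {"ns_name": "tower1"}])
assert len(hs) == 2
assert hs[1] is hs["tower1"]       # lookup by index or by namespace name
```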
@@ -6,6 +6,8 @@ import six
 from abc import abstractmethod, ABCMeta

 from ..utils.argtools import call_only_once, memoized
+from ..utils.develop import HIDE_DOC
+from ..utils import logger
 from ..input_source import PlaceholderInput
 from ..predict.base import OnlinePredictor
@@ -29,6 +31,11 @@ class TowerTrainer(Trainer):
     """

     _tower_func = None
+    _predictors = []
+    """
+    List of OnlinePredictor ever created for this trainer.
+    It is maintained for internal use.
+    """

     @call_only_once
     def _set_tower_func(self, tower_func):
@@ -93,7 +100,8 @@ class TowerTrainer(Trainer):
         """
         assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
         tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
-        device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
+        device_id = device
+        device = '/gpu:{}'.format(device_id) if device_id >= 0 else '/cpu:0'
         try:
             tower = self.tower_func.towers[tower_name]
@@ -105,22 +113,36 @@ class TowerTrainer(Trainer):
             input = PlaceholderInput()
             input.setup(self.inputs_desc)
+            vs_name = self._vs_name_for_predictor(device_id)
             with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
                     tf.device(device), PredictTowerContext(
-                        tower_name, vs_name=self._main_tower_vs_name):
+                        tower_name, vs_name=vs_name):
+                logger.info("Building graph for predict tower '{}' on device {} {}...".format(
+                    tower_name, device,
+                    "with variable scope '{}'".format(vs_name) if vs_name else ''))
                 self.tower_func(*input.get_input_tensors())
             tower = self.tower_func.towers[tower_name]

         input_tensors = tower.get_tensors(input_names)
         output_tensors = tower.get_tensors(output_names)
-        return OnlinePredictor(input_tensors, output_tensors)
+        predictor = OnlinePredictor(input_tensors, output_tensors)
+        self._predictors.append(predictor)
+        return predictor

-    @property
-    def _main_tower_vs_name(self):
-        """
-        The vs name for the "main" copy of the model,
-        to be used to build predictors.
-        """
-        return ""
+    @HIDE_DOC
+    @call_only_once
+    def initialize(self, session_creator, session_init):
+        super(TowerTrainer, self).initialize(session_creator, session_init)
+        # Predictors are created before creating the session, so they don't have an associated session.
+        for pred in self._predictors:
+            pred.sess = self.sess
+
+    def _vs_name_for_predictor(self, device):
+        towers = self.towers.training()
+        available_ids = list(range(len(towers)))
+        if device in available_ids:
+            return towers[device].vs_name
+        else:
+            return towers[0].vs_name

 @six.add_metaclass(ABCMeta)
...
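Finally, a small sketch (plain Python, no TF) of the bookkeeping that the new `_predictors` list and the `initialize` override implement: predictors can be created before the trainer's session exists, and get their `sess` attribute back-filled once `initialize` runs.

```python
class _Predictor:
    sess = None        # filled in later, mirroring OnlinePredictor.sess

class _Trainer:
    """Toy stand-in for TowerTrainer's predictor bookkeeping (not tensorpack code)."""
    def __init__(self):
        self._predictors = []
        self.sess = None

    def get_predictor(self):
        pred = _Predictor()             # built while self.sess is still None
        self._predictors.append(pred)
        return pred

    def initialize(self):
        self.sess = object()            # stand-in for creating the tf.Session
        for pred in self._predictors:
            pred.sess = self.sess       # attach the session after the fact

trainer = _Trainer()
pred = trainer.get_predictor()
assert pred.sess is None
trainer.initialize()
assert pred.sess is trainer.sess       # earlier predictors now have a session
```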