Commit 9dba9893 authored by Yuxin Wu's avatar Yuxin Wu

Let predictor share variable scope on each specific GPU

parent 30240eae
......@@ -126,6 +126,7 @@ class InferenceRunner(InferenceRunnerBase):
assert isinstance(input, InputSource), input
assert not isinstance(input, StagingInput), input
self._tower_name = tower_name
self._device_id = device
self._device = _device_from_int(device)
super(InferenceRunner, self).__init__(input, infs)
......@@ -139,11 +140,13 @@ class InferenceRunner(InferenceRunnerBase):
tower_func = self.trainer.tower_func
input_callbacks = self._input_source.setup(tower_func.inputs_desc)
logger.info("[InferenceRunner] Building tower '{}' on device {} ...".format(self._tower_name, self._device))
vs_name = self.trainer._vs_name_for_predictor(self._device_id)
logger.info("[InferenceRunner] Building tower '{}' on device {} {}...".format(
self._tower_name, self._device,
"with variable scope '{}'".format(vs_name) if vs_name else ''))
with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
tf.device(self._device), \
PredictTowerContext(
self._tower_name, vs_name=self.trainer._main_tower_vs_name):
PredictTowerContext(self._tower_name, vs_name=vs_name):
tower_func(*self._input_source.get_input_tensors())
self._tower_handle = tower_func.towers[-1]
......@@ -211,8 +214,13 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
input_callbacks = self._input_source.setup(tower_func.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, dev in enumerate(self._devices):
vs_name = self._vs_name_for_predictor(idx)
with tf.device(dev), PredictTowerContext(
self._tower_names[idx], vs_name=self.trainer._main_tower_vs_name):
self._tower_names[idx], vs_name=vs_name):
logger.info("[InferenceRunner] Building tower '{}' on device {} {}...".format(
self._tower_names[idx], dev,
"with variable scope '{}'".format(vs_name) if vs_name else ''))
# TODO log for tower creation, here or in tower.py?
tower_func(*self._input_source.get_input_tensors())
self._handles.append(tower_func.towers[-1])
......
......@@ -9,8 +9,6 @@ import six
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import PredictTowerContext
from ..input_source import PlaceholderInput
from ..utils.develop import log_deprecated
from ..utils.utils import execute_only_once
__all__ = ['PredictorBase', 'AsyncPredictorBase',
'OnlinePredictor', 'OfflinePredictor',
......@@ -27,7 +25,7 @@ class PredictorBase(object):
or just outputs
"""
def __call__(self, *args):
def __call__(self, *dp):
"""
Call the predictor on some inputs.
......@@ -38,15 +36,6 @@ class PredictorBase(object):
predictor(e1, e2)
"""
if len(args) == 1 and isinstance(args[0], (list, tuple)):
dp = args[0] # backward-compatibility
if execute_only_once():
log_deprecated(
"Calling a predictor with one datapoint",
"Call it with positional arguments instead!",
"2018-3-1")
else:
dp = args
output = self._do_call(dp)
if self.return_input:
return (dp, output)
......@@ -94,6 +83,12 @@ class OnlinePredictor(PredictorBase):
"""
ACCEPT_OPTIONS = False
""" See Session.make_callable """
sess = None
"""
The tf.Session object associated with this predictor.
"""
def __init__(self, input_tensors, output_tensors,
return_input=False, sess=None):
......@@ -104,6 +99,7 @@ class OnlinePredictor(PredictorBase):
return_input (bool): same as :attr:`PredictorBase.return_input`.
sess (tf.Session): the session this predictor runs in. If None,
will use the default session at the first call.
Note that in TensorFlow, default session is thread-local.
"""
self.return_input = return_input
self.input_tensors = input_tensors
......@@ -123,6 +119,7 @@ class OnlinePredictor(PredictorBase):
"{} != {}".format(len(dp), len(self.input_tensors))
if self.sess is None:
self.sess = tf.get_default_session()
assert self.sess is not None, "Predictor isn't called under a default session!"
if self._callable is None:
self._callable = self.sess.make_callable(
......
......@@ -280,6 +280,9 @@ class TowerTensorHandles(object):
self._handles = handles
self._name_to_handle = {k.ns_name: k for k in handles}
def __len__(self):
return len(self._handles)
def __getitem__(self, name_or_index):
"""
Args:
......
......@@ -6,6 +6,8 @@ import six
from abc import abstractmethod, ABCMeta
from ..utils.argtools import call_only_once, memoized
from ..utils.develop import HIDE_DOC
from ..utils import logger
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor
......@@ -29,6 +31,11 @@ class TowerTrainer(Trainer):
"""
_tower_func = None
_predictors = []
"""
List of OnlinePredictor ever created for this trainer.
It is maintained for internal use.
"""
@call_only_once
def _set_tower_func(self, tower_func):
......@@ -93,7 +100,8 @@ class TowerTrainer(Trainer):
"""
assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
device_id = device
device = '/gpu:{}'.format(device_id) if device_id >= 0 else '/cpu:0'
try:
tower = self.tower_func.towers[tower_name]
......@@ -105,22 +113,36 @@ class TowerTrainer(Trainer):
input = PlaceholderInput()
input.setup(self.inputs_desc)
vs_name = self._vs_name_for_predictor(device_id)
with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
tf.device(device), PredictTowerContext(
tower_name, vs_name=self._main_tower_vs_name):
tower_name, vs_name=vs_name):
logger.info("Building graph for predict tower '{}' on device {} {}...".format(
tower_name, device,
"with variable scope '{}'".format(vs_name) if vs_name else ''))
self.tower_func(*input.get_input_tensors())
tower = self.tower_func.towers[tower_name]
input_tensors = tower.get_tensors(input_names)
output_tensors = tower.get_tensors(output_names)
return OnlinePredictor(input_tensors, output_tensors)
predictor = OnlinePredictor(input_tensors, output_tensors)
self._predictors.append(predictor)
return predictor
@property
def _main_tower_vs_name(self):
"""
The vs name for the "main" copy of the model,
to be used to build predictors.
"""
return ""
@HIDE_DOC
@call_only_once
def initialize(self, session_creator, session_init):
    """
    Create the session, then attach it to every predictor that was
    built before the session existed.
    """
    super(TowerTrainer, self).initialize(session_creator, session_init)
    # Predictors built via get_predictor() before this point were created
    # without a session; hand them the freshly created one.
    for predictor in self._predictors:
        predictor.sess = self.sess
def _vs_name_for_predictor(self, device):
towers = self.towers.training()
available_ids = list(range(len(towers)))
if device in available_ids:
return towers[device].vs_name
else:
return towers[0].vs_name
@six.add_metaclass(ABCMeta)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment