Commit 098cbd97 authored by Yuxin Wu's avatar Yuxin Wu

Use custom name for predictor tower

parent ab507874
...@@ -16,7 +16,6 @@ from six.moves import range ...@@ -16,7 +16,6 @@ from six.moves import range
from ..utils import logger, get_tqdm_kwargs from ..utils import logger, get_tqdm_kwargs
from ..utils.develop import deprecated from ..utils.develop import deprecated
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.tower import TowerContext
from ..graph_builder.input_source_base import InputSource from ..graph_builder.input_source_base import InputSource
from ..graph_builder.input_source import ( from ..graph_builder.input_source import (
...@@ -90,7 +89,10 @@ class InferenceRunnerBase(Callback): ...@@ -90,7 +89,10 @@ class InferenceRunnerBase(Callback):
# Use predict_tower in train config. either gpuid or -1 # Use predict_tower in train config. either gpuid or -1
tower_id = self.trainer.config.predict_tower[0] tower_id = self.trainer.config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0' device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
tower_name = TowerContext.get_predict_tower_name(tower_id, prefix=self._prefix)
tower_name = 'InferenceRunner'
if self._prefix:
tower_name += '_' + self._prefix
self._input_source.setup(self.trainer.model.get_inputs_desc()) self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
...@@ -163,9 +165,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -163,9 +165,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
input (DataParallelFeedInput or DataFlow) input (DataParallelFeedInput or DataFlow)
gpus (list[int]): list of GPU id gpus (list[int]): list of GPU id
""" """
self._tower_names = ['InferenceRunner{}'.format(k) for k in range(len(gpus))]
if isinstance(input, DataFlow): if isinstance(input, DataFlow):
tower_names = [TowerContext.get_predict_tower_name(k) for k in range(len(gpus))] input = DataParallelFeedInput(input, self._tower_names)
input = DataParallelFeedInput(input, tower_names)
assert isinstance(input, DataParallelFeedInput), input assert isinstance(input, DataParallelFeedInput), input
super(DataParallelInferenceRunner, self).__init__(input, infs) super(DataParallelInferenceRunner, self).__init__(input, infs)
...@@ -175,8 +177,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -175,8 +177,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self._input_source.setup(self.trainer.model.get_inputs_desc()) self._input_source.setup(self.trainer.model.get_inputs_desc())
self._handles = [] self._handles = []
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for t in self._gpus: for idx, t in enumerate(self._gpus):
tower_name = TowerContext.get_predict_tower_name(t, prefix=self._prefix) tower_name = self._tower_names[idx]
device = '/gpu:{}'.format(t) device = '/gpu:{}'.format(t)
self._handles.append( self._handles.append(
self.trainer.predictor_factory.build( self.trainer.predictor_factory.build(
......
...@@ -55,7 +55,7 @@ class PredictorFactory(object): ...@@ -55,7 +55,7 @@ class PredictorFactory(object):
self._names_built = {} self._names_built = {}
def build(self, tower_name, device, input=None): def build(self, tower_name, device, input=None):
logger.info("Building predictor graph {} on device {} ...".format(tower_name, device)) logger.info("Building predictor tower '{}' on device {} ...".format(tower_name, device))
assert tower_name not in self._names_built assert tower_name not in self._names_built
with tf.device(device), \ with tf.device(device), \
...@@ -83,12 +83,12 @@ class PredictorFactory(object): ...@@ -83,12 +83,12 @@ class PredictorFactory(object):
Returns: Returns:
an online predictor (which has to be used under the default session) an online predictor (which has to be used under the default session)
""" """
tower_name = 'towerp{}'.format(tower)
tower = self._towers[tower] tower = self._towers[tower]
device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0' device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
tower_name = TowerContext.get_predict_tower_name(max(tower, 0)) # XXX
# use a previously-built tower # use a previously-built tower
# TODO conflict with inference runner?? # TODO conflict with inference runner??
if not self.has_built(tower_name): if tower_name not in self._names_built:
with tf.variable_scope(self._vs_name, reuse=True): with tf.variable_scope(self._vs_name, reuse=True):
handle = self.build(tower_name, device) handle = self.build(tower_name, device)
else: else:
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import tensorflow as tf import tensorflow as tf
from ..utils import logger from ..utils import logger
from ..tfutils import TowerContext
from ..graph_builder.predictor_factory import PredictorFactory from ..graph_builder.predictor_factory import PredictorFactory
from ..graph_builder.input_source import PlaceholderInput from ..graph_builder.input_source import PlaceholderInput
from .base import OnlinePredictor from .base import OnlinePredictor
...@@ -31,10 +30,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -31,10 +30,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
handles = [] handles = []
factory = PredictorFactory(config.model, towers) factory = PredictorFactory(config.model, towers)
for idx, t in enumerate(towers): for idx, t in enumerate(towers):
tower_name = TowerContext.get_predict_tower_name(t) tower_name = 'tower' + str(t)
device = '/gpu:' + str(t) device = '/gpu:' + str(t)
# TODO smarter TowerContext?
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0): with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
handles.append(factory.build(tower_name, device)) handles.append(factory.build(tower_name, device))
...@@ -91,7 +89,7 @@ class DataParallelOfflinePredictor(OnlinePredictor): ...@@ -91,7 +89,7 @@ class DataParallelOfflinePredictor(OnlinePredictor):
factory = PredictorFactory(config.model, towers) factory = PredictorFactory(config.model, towers)
for idx, t in enumerate(towers): for idx, t in enumerate(towers):
tower_name = TowerContext.get_predict_tower_name(t) tower_name = 'tower' + str(t)
device = '/gpu:' + str(t) device = '/gpu:' + str(t)
input = PlaceholderInput(tower_name + '/') input = PlaceholderInput(tower_name + '/')
input.setup(config.model.get_inputs_desc()) input.setup(config.model.get_inputs_desc())
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
from ..utils.naming import PREDICT_TOWER
__all__ = ['get_current_tower_context', 'TowerContext'] __all__ = ['get_current_tower_context', 'TowerContext']
...@@ -24,10 +23,6 @@ class TowerContext(object): ...@@ -24,10 +23,6 @@ class TowerContext(object):
vs_name (str): Open a variable scope with this name, if given. vs_name (str): Open a variable scope with this name, if given.
""" """
self._name = tower_name self._name = tower_name
if is_training is None:
# TODO remove this
is_training = not self._name.startswith(PREDICT_TOWER)
self._is_training = bool(is_training) self._is_training = bool(is_training)
if not self._is_training: if not self._is_training:
...@@ -84,21 +79,6 @@ class TowerContext(object): ...@@ -84,21 +79,6 @@ class TowerContext(object):
def index(self): def index(self):
return self._index return self._index
# TODO something similar for training
@staticmethod
def get_predict_tower_name(towerid=0, prefix=''):
"""
Args:
towerid(int): an integer, the id of this predict tower, usually
used to choose the GPU id.
prefix(str): an alphanumeric prefix.
Returns:
str: the final tower name used to create a predict tower.
Currently it is ``PREDICT_TOWER + prefix + towerid``.
"""
assert prefix == '' or prefix.isalnum()
return PREDICT_TOWER + prefix + str(towerid)
def __enter__(self): def __enter__(self):
global _CurrentTowerContext global _CurrentTowerContext
assert _CurrentTowerContext is None, \ assert _CurrentTowerContext is None, \
......
...@@ -29,6 +29,7 @@ def get_savename_from_varname( ...@@ -29,6 +29,7 @@ def get_savename_from_varname(
str: the name used to save the variable str: the name used to save the variable
""" """
name = varname name = varname
# TODO PREDICT_TOWER is not used anymore
if PREDICT_TOWER in name: if PREDICT_TOWER in name:
logger.error("No variable under '{}' name scope should be saved!".format(PREDICT_TOWER)) logger.error("No variable under '{}' name scope should be saved!".format(PREDICT_TOWER))
# don't overwrite anything in the current prediction graph # don't overwrite anything in the current prediction graph
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment