Commit 098cbd97 authored by Yuxin Wu's avatar Yuxin Wu

Use custom name for predictor tower

parent ab507874
......@@ -16,7 +16,6 @@ from six.moves import range
from ..utils import logger, get_tqdm_kwargs
from ..utils.develop import deprecated
from ..dataflow import DataFlow
from ..tfutils.tower import TowerContext
from ..graph_builder.input_source_base import InputSource
from ..graph_builder.input_source import (
......@@ -90,7 +89,10 @@ class InferenceRunnerBase(Callback):
# Use predict_tower in train config. either gpuid or -1
tower_id = self.trainer.config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
tower_name = TowerContext.get_predict_tower_name(tower_id, prefix=self._prefix)
tower_name = 'InferenceRunner'
if self._prefix:
tower_name += '_' + self._prefix
self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
......@@ -163,9 +165,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
input (DataParallelFeedInput or DataFlow)
gpus (list[int]): list of GPU id
"""
self._tower_names = ['InferenceRunner{}'.format(k) for k in range(len(gpus))]
if isinstance(input, DataFlow):
tower_names = [TowerContext.get_predict_tower_name(k) for k in range(len(gpus))]
input = DataParallelFeedInput(input, tower_names)
input = DataParallelFeedInput(input, self._tower_names)
assert isinstance(input, DataParallelFeedInput), input
super(DataParallelInferenceRunner, self).__init__(input, infs)
......@@ -175,8 +177,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self._input_source.setup(self.trainer.model.get_inputs_desc())
self._handles = []
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for t in self._gpus:
tower_name = TowerContext.get_predict_tower_name(t, prefix=self._prefix)
for idx, t in enumerate(self._gpus):
tower_name = self._tower_names[idx]
device = '/gpu:{}'.format(t)
self._handles.append(
self.trainer.predictor_factory.build(
......
......@@ -55,7 +55,7 @@ class PredictorFactory(object):
self._names_built = {}
def build(self, tower_name, device, input=None):
logger.info("Building predictor graph {} on device {} ...".format(tower_name, device))
logger.info("Building predictor tower '{}' on device {} ...".format(tower_name, device))
assert tower_name not in self._names_built
with tf.device(device), \
......@@ -83,12 +83,12 @@ class PredictorFactory(object):
Returns:
an online predictor (which has to be used under the default session)
"""
tower_name = 'towerp{}'.format(tower)
tower = self._towers[tower]
device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
tower_name = TowerContext.get_predict_tower_name(max(tower, 0)) # XXX
# use a previously-built tower
# TODO conflict with inference runner??
if not self.has_built(tower_name):
if tower_name not in self._names_built:
with tf.variable_scope(self._vs_name, reuse=True):
handle = self.build(tower_name, device)
else:
......
......@@ -5,7 +5,6 @@
import tensorflow as tf
from ..utils import logger
from ..tfutils import TowerContext
from ..graph_builder.predictor_factory import PredictorFactory
from ..graph_builder.input_source import PlaceholderInput
from .base import OnlinePredictor
......@@ -31,10 +30,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
handles = []
factory = PredictorFactory(config.model, towers)
for idx, t in enumerate(towers):
tower_name = TowerContext.get_predict_tower_name(t)
tower_name = 'tower' + str(t)
device = '/gpu:' + str(t)
# TODO smarter TowerContext?
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
handles.append(factory.build(tower_name, device))
......@@ -91,7 +89,7 @@ class DataParallelOfflinePredictor(OnlinePredictor):
factory = PredictorFactory(config.model, towers)
for idx, t in enumerate(towers):
tower_name = TowerContext.get_predict_tower_name(t)
tower_name = 'tower' + str(t)
device = '/gpu:' + str(t)
input = PlaceholderInput(tower_name + '/')
input.setup(config.model.get_inputs_desc())
......
......@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ..utils.naming import PREDICT_TOWER
__all__ = ['get_current_tower_context', 'TowerContext']
......@@ -24,10 +23,6 @@ class TowerContext(object):
vs_name (str): Open a variable scope with this name, if given.
"""
self._name = tower_name
if is_training is None:
# TODO remove this
is_training = not self._name.startswith(PREDICT_TOWER)
self._is_training = bool(is_training)
if not self._is_training:
......@@ -84,21 +79,6 @@ class TowerContext(object):
def index(self):
return self._index
# TODO something similar for training
@staticmethod
def get_predict_tower_name(towerid=0, prefix=''):
"""
Args:
towerid(int): an integer, the id of this predict tower, usually
used to choose the GPU id.
prefix(str): an alphanumeric prefix.
Returns:
str: the final tower name used to create a predict tower.
Currently it is ``PREDICT_TOWER + prefix + towerid``.
"""
assert prefix == '' or prefix.isalnum()
return PREDICT_TOWER + prefix + str(towerid)
def __enter__(self):
global _CurrentTowerContext
assert _CurrentTowerContext is None, \
......
......@@ -29,6 +29,7 @@ def get_savename_from_varname(
str: the name used to save the variable
"""
name = varname
# TODO PREDICT_TOWER is not used anymore
if PREDICT_TOWER in name:
logger.error("No variable under '{}' name scope should be saved!".format(PREDICT_TOWER))
# don't overwrite anything in the current prediction graph
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment