Commit 8a80b036 authored by Yuxin Wu

remove PredictorTowerBuilder

parent 7699fd9b
@@ -7,17 +7,11 @@ from abc import abstractmethod, ABCMeta
 import tensorflow as tf
 import six
 
-from ..utils import logger
-from ..utils.argtools import memoized
-from ..utils.naming import TOWER_FREEZE_KEYS
-from ..tfutils.common import get_tensors_by_names, get_op_tensor_name
+from ..tfutils.common import get_tensors_by_names
 from ..tfutils.tower import TowerContext
-from ..tfutils.collection import freeze_collection
 
 __all__ = ['PredictorBase', 'AsyncPredictorBase',
            'OnlinePredictor', 'OfflinePredictor',
-           'PredictorTowerBuilder',
-           'build_prediction_graph',
            ]
@@ -144,75 +138,3 @@ class OfflinePredictor(OnlinePredictor):
         config.session_init.init(sess)
         super(OfflinePredictor, self).__init__(
             input_tensors, output_tensors, config.return_input, sess)
-
-
-class PredictorTowerBuilder(object):
-    """
-    A builder which caches the predictor tower it has built.
-    """
-    def __init__(self, build_tower_fn, prefix=''):
-        """
-        Args:
-            build_tower_fn: a function that will be called inside each tower, taking the tower id as its argument.
-            prefix: an extra prefix in the tower name. The final tower prefix is
-                determined by :meth:`TowerContext.get_predict_tower_name`.
-        """
-        self._fn = build_tower_fn
-        self._prefix = prefix
-
-    @memoized
-    def build(self, tower):
-        """
-        Args:
-            tower (int): the tower will be built on device '/gpu:{tower}', or
-                '/cpu:0' if tower is -1.
-        """
-        towername = TowerContext.get_predict_tower_name(tower, self._prefix)
-        if self._prefix:
-            msg = "Building predictor graph {} on gpu={} with prefix='{}' ...".format(
-                towername, tower, self._prefix)
-        else:
-            msg = "Building predictor graph {} on gpu={} ...".format(towername, tower)
-        logger.info(msg)
-        # No matter where this gets called, clear any existing name scope.
-        device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
-        with tf.name_scope(None), \
-                freeze_collection(TOWER_FREEZE_KEYS), \
-                tf.device(device), \
-                TowerContext(towername, is_training=False):
-            self._fn(tower)
-
-    # Useful only when the placeholders don't have a tower prefix.
-    # Note that in the DataParallel predictor, placeholders do have a tower prefix.
-    @staticmethod
-    def get_tensors_maybe_in_tower(placeholder_names, names, tower, prefix=''):
-        """
-        Args:
-            placeholder_names (list): a list of op names of the placeholders.
-            names (list): a list of names of the tensors to fetch.
-            tower (int): relative GPU id.
-        """
-        def maybe_inside_tower(name):
-            name = get_op_tensor_name(name)[0]
-            if name in placeholder_names:
-                return name
-            else:
-                # If the name is not a placeholder, use its name in each tower.
-                return TowerContext.get_predict_tower_name(tower, prefix) + '/' + name
-        names = list(map(maybe_inside_tower, names))
-        tensors = get_tensors_by_names(names)
-        return tensors
-
-
-def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
-    """
-    Execute `build_tower_fn` on each tower.
-    Just a wrapper around :class:`PredictorTowerBuilder` to run on several
-    towers together.
-    """
-    builder = PredictorTowerBuilder(build_tower_fn, prefix)
-    for idx, t in enumerate(towers):
-        # The first variable scope may or may not reuse (depending on the existing
-        # context), but the rest have to reuse.
-        with tf.variable_scope(
-                tf.get_variable_scope(), reuse=True if idx > 0 else None):
-            builder.build(t)
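
For reference, a minimal sketch of how the removed API was typically used before this commit: build_prediction_graph wrapped PredictorTowerBuilder to replicate an inference graph across towers with variable reuse, and get_tensors_maybe_in_tower mapped plain tensor names into a tower's name scope. The model function my_model_fn and the tensor names 'input'/'output' below are hypothetical placeholders for illustration, not from this commit.

# Hypothetical usage sketch of the removed API (TF 1.x style).
# `my_model_fn` and the tensor names are illustrative assumptions.
import tensorflow as tf

def my_model_fn(tower_id):
    # Builds the inference graph for one tower. Because build_prediction_graph
    # reuses the variable scope after the first tower, all towers share weights.
    x = tf.placeholder(tf.float32, [None, 784], name='input')
    logits = tf.layers.dense(x, 10, name='fc')
    tf.identity(logits, name='output')

# Replicate the predictor graph on GPU 0 and GPU 1:
build_prediction_graph(my_model_fn, towers=[0, 1])

# To fetch a per-tower output while keeping a shared placeholder, the static
# helper maps non-placeholder names into the tower's predict-tower name scope
# (the exact prefix is decided by TowerContext.get_predict_tower_name):
tensors = PredictorTowerBuilder.get_tensors_maybe_in_tower(
    ['input'], ['input', 'output'], tower=1)
# 'input' stays as-is (it is a placeholder); 'output' resolves inside tower 1.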