Commit fa025551 authored by Yuxin Wu

refactor predictors.

parent dbfa9982
@@ -183,6 +183,7 @@ class FeedfreeInferenceRunner(Triggerable):

     def _setup_graph(self):
         self._find_input_tensors()  # tensors
+        # TODO reuse predictor code
         # overwrite the FeedfreeInferenceRunner name scope
         with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
                 tf.name_scope(None), \
@@ -190,7 +191,7 @@ class FeedfreeInferenceRunner(Triggerable):
             def fn(_):
                 self.trainer.model.build_graph(self._input_tensors)
             build_prediction_graph(fn, [0], prefix=self._prefix)
-        self._tower_prefix = TowerContext.get_predict_tower_name(self._prefix, 0)
+        self._tower_prefix = TowerContext.get_predict_tower_name(0, self._prefix)
         self._find_output_tensors()
...
@@ -8,11 +8,15 @@ import tensorflow as tf
 import six

 from ..utils import logger
+from ..utils.argtools import memoized
+from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..tfutils import get_tensors_by_names, TowerContext
+from ..tfutils.collection import freeze_collection

 __all__ = ['PredictorBase', 'AsyncPredictorBase',
            'OnlinePredictor', 'OfflinePredictor',
            'get_predict_func',
+           'PredictorTowerBuilder',
            'build_prediction_graph',
            ]
@@ -119,14 +123,15 @@ class OnlinePredictor(PredictorBase):

 class OfflinePredictor(OnlinePredictor):
-    """ A predictor built from a given config, in a new graph. """
+    """ A predictor built from a given config.
+        A single-tower model will be built without any prefix. """

     def __init__(self, config):
         """
         Args:
             config (PredictConfig): the config to use.
         """
-        self.graph = tf.Graph()
+        self.graph = config._maybe_create_graph()
         with self.graph.as_default():
             input_placehdrs = config.model.get_reused_placehdrs()
             with TowerContext('', False):
@@ -148,23 +153,52 @@ def get_predict_func(config):
     return OfflinePredictor(config)


-def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
+class PredictorTowerBuilder(object):
     """
-    Build graph on each tower.
-    Args:
-        build_tower_fn: a function that will be called inside each tower,
-            taking tower id as the argument.
-        towers: a list of relative GPU id.
-        prefix: an extra prefix in tower name. The final tower prefix will be
-            determined by :meth:`TowerContext.get_predict_tower_name`.
+    A builder which caches the predictor tower it has built.
     """
-    for idx, k in enumerate(towers):
-        logger.info(
-            "Building prediction graph for towerid={} with prefix='{}' ...".format(k, prefix))
-        towername = TowerContext.get_predict_tower_name(prefix, k)
-        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                TowerContext(towername, is_training=False), \
-                tf.variable_scope(tf.get_variable_scope(),
-                                  reuse=True if idx > 0 else None):
-            build_tower_fn(k)
+    def __init__(self, build_tower_fn, prefix=''):
+        """
+        Args:
+            build_tower_fn: a function that will be called inside each tower, taking tower id as the argument.
+            prefix: an extra prefix in tower name. The final tower prefix will be
+                determined by :meth:`TowerContext.get_predict_tower_name`.
+        """
+        self._fn = build_tower_fn
+        self._prefix = prefix
+
+    @memoized
+    def build(self, tower):
+        """
+        Args:
+            tower (int): the tower will be built on device '/gpu:{tower}', or
+                '/cpu:0' if tower is -1.
+        """
+        towername = TowerContext.get_predict_tower_name(tower, self._prefix)
+        if self._prefix:
+            msg = "Building predictor graph {} on gpu={} with prefix='{}' ...".format(
+                towername, tower, self._prefix)
+        else:
+            msg = "Building predictor graph {} on gpu={} ...".format(towername, tower)
+        logger.info(msg)
+        # No matter where this gets called, clear any existing name scope.
+        with tf.name_scope(None), \
+                freeze_collection(SUMMARY_BACKUP_KEYS), \
+                tf.device('/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'), \
+                TowerContext(towername, is_training=False):
+            self._fn(tower)
+
+
+def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
+    """
+    Execute `build_tower_fn` on each tower.
+    Just a wrapper on :class:`PredictorTowerBuilder` to run on several towers
+    together.
+    """
+    builder = PredictorTowerBuilder(build_tower_fn, prefix)
+    for idx, t in enumerate(towers):
+        # The first variable scope may or may not reuse (depending on the existing
+        # context), but the rest have to reuse.
+        with tf.variable_scope(
+                tf.get_variable_scope(), reuse=True if idx > 0 else None):
+            builder.build(t)
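For reference, a minimal sketch of how the new builder is meant to be used; the no-op tower function and the explicit graph here are placeholders rather than part of this commit, and the import path is assumed from the package layout shown in the diff:

    import tensorflow as tf
    from tensorpack.predict import PredictorTowerBuilder

    def my_tower_fn(tower_id):
        # a real tower function would call model.build_graph(...) here,
        # inside the prediction TowerContext set up by the builder
        pass

    with tf.Graph().as_default():
        builder = PredictorTowerBuilder(my_tower_fn)
        builder.build(0)    # builds the predict tower for GPU 0
        builder.build(0)    # @memoized: the same tower is not rebuilt
        builder.build(-1)   # tower -1 is placed on /cpu:0

`build_prediction_graph` keeps its old signature, but is now just a loop that calls `builder.build()` for each tower, reusing variables for every tower after the first.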
@@ -2,6 +2,7 @@
 # File: config.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

+import tensorflow as tf
 import six

 from ..models import ModelDesc
@@ -17,10 +18,12 @@ class PredictConfig(object):
     def __init__(self, model,
                  session_creator=None,
                  session_init=None,
-                 session_config=None,
                  input_names=None,
                  output_names=None,
-                 return_input=False):
+                 return_input=False,
+                 create_graph=True,
+                 session_config=None,  # deprecated
+                 ):
         """
         Args:
             model (ModelDesc): the model to use.
@@ -32,7 +35,9 @@ class PredictConfig(object):
                 inputs of the model.
             output_names (list): a list of names of the output tensors to predict, the
                 tensors can be any computable tensor in the graph.
-            return_input: same as in :attr:`PredictorBase.return_input`.
+            return_input (bool): same as in :attr:`PredictorBase.return_input`.
+            create_graph (bool): create a new graph, or use the default graph
+                when the predictor is first initialized.
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
@@ -68,3 +73,9 @@ class PredictConfig(object):
         assert len(self.output_names), self.output_names

         self.return_input = bool(return_input)
+        self.create_graph = bool(create_graph)
+
+    def _maybe_create_graph(self):
+        if self.create_graph:
+            return tf.Graph()
+        return tf.get_default_graph()
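A short sketch of what the new flag changes for callers; `my_model` stands in for an existing ModelDesc instance and the tensor names are illustrative only:

    from tensorpack.predict import PredictConfig, OfflinePredictor

    config = PredictConfig(model=my_model,
                           input_names=['input'],
                           output_names=['output'],
                           create_graph=False)   # reuse the caller's default graph

    # OfflinePredictor and the multi-tower predictors now start with
    #   self.graph = config._maybe_create_graph()
    # which returns tf.get_default_graph() here, or a fresh tf.Graph()
    # when create_graph is left at its default of True.
    predictor = OfflinePredictor(config)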
@@ -3,9 +3,8 @@
 # File: multigpu.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import tensorflow as tf
-
-from ..tfutils import get_tensors_by_names, TowerContext
+from ..utils import logger
+from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
 from .base import OnlinePredictor, build_prediction_graph

 __all__ = ['MultiTowerOfflinePredictor',
@@ -21,10 +20,12 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
             config (PredictConfig): the config to use.
             towers: a list of relative GPU id.
         """
-        self.graph = tf.Graph()
+        assert len(towers) > 0
+        self.graph = config._maybe_create_graph()
         self.predictors = []
         with self.graph.as_default():
-            # TODO backup summary keys?
+            placeholder_names = set([k.name for k in config.model.get_inputs_desc()])
+
             def fn(_):
                 config.model.build_graph(config.model.get_reused_placehdrs())
             build_prediction_graph(fn, towers)
@@ -32,25 +33,46 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
             self.sess = config.session_creator.create_session()
             config.session_init.init(self.sess)

-            input_tensors = get_tensors_by_names(config.input_names)
-
+            get_tensor_fn = MultiTowerOfflinePredictor.get_tensors_maybe_in_tower
             for k in towers:
-                output_tensors = get_tensors_by_names(
-                    [TowerContext.get_predict_towre_name('', k) + '/' + n
-                     for n in config.output_names])
+                input_tensors = get_tensor_fn(placeholder_names, config.input_names, k)
+                output_tensors = get_tensor_fn(placeholder_names, config.output_names, k)
                 self.predictors.append(OnlinePredictor(
                     input_tensors, output_tensors, config.return_input, self.sess))

+    @staticmethod
+    def get_tensors_maybe_in_tower(placeholder_names, names, k):
+        def maybe_inside_tower(name):
+            name = get_op_tensor_name(name)[0]
+            if name in placeholder_names:
+                return name
+            else:
+                # if the name is not a placeholder, use its name in each tower
+                return TowerContext.get_predict_tower_name(k) + '/' + name
+        names = map(maybe_inside_tower, names)
+        tensors = get_tensors_by_names(names)
+        return tensors
+
     def _do_call(self, dp):
         # use the first tower for compatible PredictorBase interface
         return self.predictors[0]._do_call(dp)

-    def get_predictors(self, n):
+    def get_predictor(self, n):
+        """
+        Returns:
+            PredictorBase: the nth predictor on the nth tower.
+        """
+        l = len(self.predictors)
+        if n >= l:
+            logger.warn("n > #towers, will assign predictor to GPU by round-robin")
+        return [self.predictors[k % l] for k in range(n)]
+
+    def get_predictors(self):
         """
         Returns:
-            PredictorBase: the nth predictor on the nth GPU.
+            list[PredictorBase]: a list of predictors
         """
-        return [self.predictors[k % len(self.predictors)] for k in range(n)]
+        return self.predictors


 class DataParallelOfflinePredictor(OnlinePredictor):
@@ -66,7 +88,7 @@ class DataParallelOfflinePredictor(OnlinePredictor):
             config (PredictConfig): the config to use.
             towers: a list of relative GPU id.
         """
-        self.graph = tf.Graph()
+        self.graph = config._maybe_create_graph()
         with self.graph.as_default():
             input_names = []
             output_tensors = []
...
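Sketch of the renamed accessors and the new name resolution; the config object and the tensor names are assumed here, only the method names come from this diff:

    pred = MultiTowerOfflinePredictor(config, towers=[0, 1])

    four = pred.get_predictor(4)       # 4 predictors assigned to the 2 towers round-robin
                                       # (warns because 4 > number of towers)
    all_preds = pred.get_predictors()  # one OnlinePredictor per tower

    # Names that are not input placeholders are looked up inside the requested tower:
    # get_tensors_maybe_in_tower({'input'}, ['input', 'prob'], 0) keeps 'input' as-is
    # and resolves 'prob' under TowerContext.get_predict_tower_name(0).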
@@ -73,7 +73,7 @@ class TowerContext(object):
         return graph.get_tensor_by_name(newname)

     @staticmethod
-    def get_predict_tower_name(prefix, towerid=0):
+    def get_predict_tower_name(towerid=0, prefix=''):
         """
         Args:
             prefix(str): an alphanumeric prefix.
@@ -91,6 +91,7 @@ class TowerContext(object):
         assert _CurrentTowerContext is None, \
             "Nesting TowerContext!"
         _CurrentTowerContext = self
+        # TODO enter name_scope(None) first
         if len(self._name):
             self._scope = tf.name_scope(self._name)
             return self._scope.__enter__()
...
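The staticmethod now takes the tower id first, which is what the updated call sites in base.py and the inference runner use:

    # old call sites:  TowerContext.get_predict_tower_name(prefix, towerid)
    # new call sites:  TowerContext.get_predict_tower_name(towerid, prefix)
    TowerContext.get_predict_tower_name(0)                   # default empty prefix
    TowerContext.get_predict_tower_name(0, prefix='custom')  # with an extra prefix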
@@ -124,6 +124,7 @@ class QueueInput(FeedfreeInput):
     def size(self):
         return self.ds.size()

+    # TODO XXX use input data mapping. not all placeholders are needed
     def _setup(self, trainer):
         self.input_placehdrs = trainer.model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
...
@@ -4,11 +4,8 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf

-from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
-from ..tfutils.collection import freeze_collection
-from ..utils.argtools import memoized
-from ..tfutils import get_tensors_by_names, get_op_tensor_name
-from ..predict import OnlinePredictor, build_prediction_graph
+from ..predict import (OnlinePredictor,
+                       PredictorTowerBuilder, MultiTowerOfflinePredictor)

 __all__ = ['PredictorFactory']
@@ -23,45 +20,27 @@ class PredictorFactory(object):
         """
         self.model = trainer.model
         self.towers = trainer.config.predict_tower
+
+        def fn(_):
+            self.model.build_graph(self.model.get_reused_placehdrs())
+        self._tower_builder = PredictorTowerBuilder(fn)
         assert isinstance(self.towers, list)

     # TODO sess option
     def get_predictor(self, input_names, output_names, tower):
         """
         Args:
-            tower (int): need the kth tower (not the gpu id)
+            tower (int): need the kth tower (not the gpu id, but the id in TrainConfig.predict_tower)
         Returns:
             an online predictor (which has to be used under a default session)
         """
-        self._build_predict_tower()
-        tower = self.towers[tower]
-
+        tower = self.towers[tower]  # TODO is it good?
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+            # just ensure the tower exists. won't rebuild
+            self._tower_builder.build(tower)
         placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
-
-        def get_name_in_tower(name):
-            return PREDICT_TOWER + str(tower) + '/' + name
-
-        def maybe_inside_tower(name):
-            name = get_op_tensor_name(name)[0]
-            if name in placeholder_names:
-                return name
-            else:
-                return get_name_in_tower(name)
-
-        input_names = map(maybe_inside_tower, input_names)
-        raw_input_tensors = get_tensors_by_names(input_names)
-
-        output_names = map(get_name_in_tower, output_names)
-        output_tensors = get_tensors_by_names(output_names)
-        return OnlinePredictor(raw_input_tensors, output_tensors)
-
-    @memoized
-    def _build_predict_tower(self):
-        # build_predict_tower might get called anywhere, but 'PREDICT_TOWER'
-        # should always be the outermost name scope
-        with tf.name_scope(None), \
-                freeze_collection(SUMMARY_BACKUP_KEYS), \
-                tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            def fn(_):
-                self.model.build_graph(self.model.get_reused_placehdrs())
-            build_prediction_graph(fn, self.towers)
+        get_tensor_fn = MultiTowerOfflinePredictor.get_tensors_maybe_in_tower
+        in_tensors = get_tensor_fn(placeholder_names, input_names, tower)
+        out_tensors = get_tensor_fn(placeholder_names, output_names, tower)
+        return OnlinePredictor(in_tensors, out_tensors)
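Sketch of the trainer-side flow after the refactor, assuming the factory is still constructed from a trainer as before; the tensor names are placeholders:

    factory = PredictorFactory(trainer)

    # The first request for a tower builds it through the shared PredictorTowerBuilder;
    # later requests hit the @memoized build() and only resolve tensors via
    # MultiTowerOfflinePredictor.get_tensors_maybe_in_tower.
    pred = factory.get_predictor(['input'], ['prob'], tower=0)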