Commit 0fdef168 authored by Yuxin Wu

remove PredictorFactory, and build offline predictor by tower_func

parent c2c895ab
...@@ -6,8 +6,7 @@ import tensorflow as tf ...@@ -6,8 +6,7 @@ import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
from ..utils import logger from ..utils import logger
from ..tfutils.tower import TowerContext, TowerFuncWrapper from ..tfutils.tower import TowerContext
from ..input_source import PlaceholderInput
from .training import GraphBuilder from .training import GraphBuilder
__all__ = ['SimplePredictBuilder'] __all__ = ['SimplePredictBuilder']
...@@ -58,72 +57,3 @@ class SimplePredictBuilder(GraphBuilder): ...@@ -58,72 +57,3 @@ class SimplePredictBuilder(GraphBuilder):
inputs = input.get_input_tensors() inputs = input.get_input_tensors()
assert isinstance(inputs, (list, tuple)), inputs assert isinstance(inputs, (list, tuple)), inputs
return tower_fn(*inputs) return tower_fn(*inputs)
class PredictorFactory(object):
""" Make predictors from :class:`ModelDesc`."""
def __init__(self, model, vs_name=''):
"""
Args:
model (ModelDesc):
vs_name (str):
"""
self._model = model
self._vs_name = vs_name
self._names_built = {}
def build(self, tower_name, device, input=None):
"""
Args:
tower_name (str):
device(str):
input (InputSource): must be setup already. If None, will use InputDesc from the model.
"""
logger.info("Building predictor tower '{}' on device {} ...".format(tower_name, device))
assert tower_name not in self._names_built, \
"Prediction tower with name '{}' already exists!".format(tower_name)
with tf.device(device), \
TowerContext(tower_name, is_training=False):
inputs_desc = self._model.get_inputs_desc()
if input is None:
input = PlaceholderInput()
input.setup(inputs_desc)
inputs = input.get_input_tensors()
assert isinstance(inputs, (list, tuple)), inputs
def tower_func(*inputs):
self._model.build_graph(inputs)
tower_func = TowerFuncWrapper(tower_func, inputs_desc)
tower_func(*inputs)
self._names_built[tower_name] = tower_func.towers[0]
return self._names_built[tower_name]
def has_built(self, tower_name):
return tower_name in self._names_built
def get_predictor(self, input_names, output_names, tower):
"""
Args:
tower (int): use device '/gpu:{tower}' or use -1 for '/cpu:0'.
Returns:
an online predictor (which has to be used under a default session)
"""
tower_name = 'towerp{}'.format(tower)
device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
# use a previously-built tower
# TODO check conflict with inference runner??
if tower_name not in self._names_built:
with tf.variable_scope(self._vs_name, reuse=True):
handle = self.build(tower_name, device)
else:
handle = self._names_built[tower_name]
in_tensors = handle.get_tensors(input_names)
out_tensors = handle.get_tensors(output_names)
from ..predict import OnlinePredictor # noqa TODO
return OnlinePredictor(in_tensors, out_tensors)
...@@ -146,7 +146,7 @@ class OnlinePredictor(PredictorBase): ...@@ -146,7 +146,7 @@ class OnlinePredictor(PredictorBase):
class OfflinePredictor(OnlinePredictor): class OfflinePredictor(OnlinePredictor):
""" A predictor built from a given config. """ A predictor built from a given config.
A sinlge-tower model will be built without any prefix. """ A single-tower model will be built without any prefix. """
def __init__(self, config): def __init__(self, config):
""" """
...@@ -156,9 +156,9 @@ class OfflinePredictor(OnlinePredictor): ...@@ -156,9 +156,9 @@ class OfflinePredictor(OnlinePredictor):
self.graph = config._maybe_create_graph() self.graph = config._maybe_create_graph()
with self.graph.as_default(): with self.graph.as_default():
input = PlaceholderInput() input = PlaceholderInput()
input.setup(config.model.get_inputs_desc()) input.setup(config.inputs_desc)
with TowerContext('', is_training=False): with TowerContext('', is_training=False):
config.model.build_graph(input.get_input_tensors()) config.tower_func(*input.get_input_tensors())
input_tensors = get_tensors_by_names(config.input_names) input_tensors = get_tensors_by_names(config.input_names)
output_tensors = get_tensors_by_names(config.output_names) output_tensors = get_tensors_by_names(config.output_names)
......
...@@ -7,14 +7,17 @@ import six ...@@ -7,14 +7,17 @@ import six
from ..graph_builder import ModelDescBase from ..graph_builder import ModelDescBase
from ..tfutils import get_default_sess_config from ..tfutils import get_default_sess_config
from ..tfutils.tower import TowerFuncWrapper
from ..tfutils.sessinit import SessionInit, JustCurrentSession from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSessionCreator
__all__ = ['PredictConfig'] __all__ = ['PredictConfig']
class PredictConfig(object): class PredictConfig(object):
def __init__(self, model, def __init__(self,
model=None,
inputs_desc=None,
tower_func=None,
session_creator=None, session_creator=None,
session_init=None, session_init=None,
input_names=None, input_names=None,
...@@ -24,9 +27,12 @@ class PredictConfig(object): ...@@ -24,9 +27,12 @@ class PredictConfig(object):
): ):
""" """
Args: Args:
model (ModelDescBase): the model to use. model (ModelDescBase): the model to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors
session_creator (tf.train.SessionCreator): how to create the session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSessionCreator()`. session. Defaults to :class:`tf.train.ChiefSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session. session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing. Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to all input_names (list): a list of input tensor names. Defaults to all
...@@ -36,11 +42,20 @@ class PredictConfig(object): ...@@ -36,11 +42,20 @@ class PredictConfig(object):
return_input (bool): same as in :attr:`PredictorBase.return_input`. return_input (bool): same as in :attr:`PredictorBase.return_input`.
create_graph (bool): create a new graph, or use the default graph create_graph (bool): create a new graph, or use the default graph
when then predictor is first initialized. when then predictor is first initialized.
You need to set either `model`, or `inputs_desc` plus `tower_func`.
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
self.model = model if model is not None:
assert_type(self.model, ModelDescBase) assert_type(model, ModelDescBase)
assert inputs_desc is None and tower_func is None
self.inputs_desc = model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
else:
assert inputs_desc is not None and tower_func is not None
self.inputs_desc = inputs_desc
self.tower_func = TowerFuncWrapper(tower_func, inputs_desc)
if session_init is None: if session_init is None:
session_init = JustCurrentSession() session_init = JustCurrentSession()
...@@ -48,7 +63,7 @@ class PredictConfig(object): ...@@ -48,7 +63,7 @@ class PredictConfig(object):
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit)
if session_creator is None: if session_creator is None:
self.session_creator = NewSessionCreator(config=get_default_sess_config()) self.session_creator = tf.train.ChiefSessionCreator(config=get_default_sess_config())
else: else:
self.session_creator = session_creator self.session_creator = session_creator
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import tensorflow as tf import tensorflow as tf
from ..utils import logger from ..utils import logger
from ..graph_builder.predictor_factory import PredictorFactory from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import PlaceholderInput from ..input_source import PlaceholderInput
from .base import OnlinePredictor from .base import OnlinePredictor
...@@ -28,13 +28,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -28,13 +28,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self.return_input = config.return_input self.return_input = config.return_input
with self.graph.as_default(): with self.graph.as_default():
handles = [] handles = []
factory = PredictorFactory(config.model, towers)
input = PlaceholderInput()
input.setup(config.inputs_desc)
for idx, t in enumerate(towers): for idx, t in enumerate(towers):
tower_name = 'tower' + str(t) tower_name = 'tower' + str(t)
device = '/gpu:' + str(t)
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0): with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
handles.append(factory.build(tower_name, device)) builder = SimplePredictBuilder(ns_name=tower_name, device=t)
builder.build(input, config.tower_func)
handles.append(config.tower_func.towers[-1])
self.sess = config.session_creator.create_session() self.sess = config.session_creator.create_session()
config.session_init.init(self.sess) config.session_init.init(self.sess)
...@@ -87,15 +91,15 @@ class DataParallelOfflinePredictor(OnlinePredictor): ...@@ -87,15 +91,15 @@ class DataParallelOfflinePredictor(OnlinePredictor):
input_tensors = [] input_tensors = []
output_tensors = [] output_tensors = []
factory = PredictorFactory(config.model, towers)
for idx, t in enumerate(towers): for idx, t in enumerate(towers):
tower_name = 'tower' + str(t) tower_name = 'tower' + str(t)
device = '/gpu:' + str(t)
input = PlaceholderInput(tower_name + '/') input = PlaceholderInput(tower_name + '/')
input.setup(config.model.get_inputs_desc()) input.setup(config.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0): with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
h = factory.build(tower_name, device, ) builder = SimplePredictBuilder(ns_name=tower_name, device=t)
builder.build(input, config.tower_func)
h = config.tower_func.towers[-1]
input_tensors.extend(h.get_tensors(config.input_names)) input_tensors.extend(h.get_tensors(config.input_names))
output_tensors.extend(h.get_tensors(config.output_names)) output_tensors.extend(h.get_tensors(config.output_names))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment