Commit 0fdef168 authored by Yuxin Wu's avatar Yuxin Wu

remove PredictorFactory, and build offline predictor by tower_func

parent c2c895ab
......@@ -6,8 +6,7 @@ import tensorflow as tf
from contextlib import contextmanager
from ..utils import logger
from ..tfutils.tower import TowerContext, TowerFuncWrapper
from ..input_source import PlaceholderInput
from ..tfutils.tower import TowerContext
from .training import GraphBuilder
__all__ = ['SimplePredictBuilder']
......@@ -58,72 +57,3 @@ class SimplePredictBuilder(GraphBuilder):
inputs = input.get_input_tensors()
assert isinstance(inputs, (list, tuple)), inputs
return tower_fn(*inputs)
class PredictorFactory(object):
    """Build and cache prediction towers for a :class:`ModelDesc`."""

    def __init__(self, model, vs_name=''):
        """
        Args:
            model (ModelDesc): the model whose graph is built for prediction.
            vs_name (str): variable scope to reuse when building new towers.
        """
        self._model = model
        self._vs_name = vs_name
        self._names_built = {}  # tower_name -> tower handle

    def build(self, tower_name, device, input=None):
        """Construct one prediction tower on the given device.

        Args:
            tower_name (str): unique name for the new tower.
            device (str): device spec the tower is placed on.
            input (InputSource): must be setup already. If None, will use InputDesc from the model.

        Returns:
            the tower handle registered under ``tower_name``.
        """
        logger.info("Building predictor tower '{}' on device {} ...".format(tower_name, device))
        assert tower_name not in self._names_built, \
            "Prediction tower with name '{}' already exists!".format(tower_name)

        with tf.device(device), TowerContext(tower_name, is_training=False):
            desc = self._model.get_inputs_desc()
            if input is None:
                input = PlaceholderInput()
                input.setup(desc)
            tensors = input.get_input_tensors()
            assert isinstance(tensors, (list, tuple)), tensors

            # Wrap the model's graph-building function so that the tower
            # created by the call below is tracked by the wrapper.
            def _build_fn(*args):
                self._model.build_graph(args)
            wrapped = TowerFuncWrapper(_build_fn, desc)
            wrapped(*tensors)

            handle = wrapped.towers[0]
            self._names_built[tower_name] = handle
            return handle

    def has_built(self, tower_name):
        """Return whether a tower with this name was already built."""
        return tower_name in self._names_built

    def get_predictor(self, input_names, output_names, tower):
        """
        Args:
            tower (int): use device '/gpu:{tower}' or use -1 for '/cpu:0'.
        Returns:
            an online predictor (which has to be used under a default session)
        """
        tower_name = 'towerp{}'.format(tower)
        device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
        # use a previously-built tower
        # TODO check conflict with inference runner??
        if tower_name in self._names_built:
            handle = self._names_built[tower_name]
        else:
            with tf.variable_scope(self._vs_name, reuse=True):
                handle = self.build(tower_name, device)

        from ..predict import OnlinePredictor  # noqa TODO
        return OnlinePredictor(handle.get_tensors(input_names),
                               handle.get_tensors(output_names))
......@@ -146,7 +146,7 @@ class OnlinePredictor(PredictorBase):
class OfflinePredictor(OnlinePredictor):
""" A predictor built from a given config.
A sinlge-tower model will be built without any prefix. """
A single-tower model will be built without any prefix. """
def __init__(self, config):
"""
......@@ -156,9 +156,9 @@ class OfflinePredictor(OnlinePredictor):
self.graph = config._maybe_create_graph()
with self.graph.as_default():
input = PlaceholderInput()
input.setup(config.model.get_inputs_desc())
input.setup(config.inputs_desc)
with TowerContext('', is_training=False):
config.model.build_graph(input.get_input_tensors())
config.tower_func(*input.get_input_tensors())
input_tensors = get_tensors_by_names(config.input_names)
output_tensors = get_tensors_by_names(config.output_names)
......
......@@ -7,14 +7,17 @@ import six
from ..graph_builder import ModelDescBase
from ..tfutils import get_default_sess_config
from ..tfutils.tower import TowerFuncWrapper
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSessionCreator
__all__ = ['PredictConfig']
class PredictConfig(object):
def __init__(self, model,
def __init__(self,
model=None,
inputs_desc=None,
tower_func=None,
session_creator=None,
session_init=None,
input_names=None,
......@@ -24,9 +27,12 @@ class PredictConfig(object):
):
"""
Args:
model (ModelDescBase): the model to use.
model (ModelDescBase): the model to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSessionCreator()`.
session. Defaults to :class:`tf.train.ChiefSessionCreator()`.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to all
......@@ -36,11 +42,20 @@ class PredictConfig(object):
return_input (bool): same as in :attr:`PredictorBase.return_input`.
create_graph (bool): create a new graph, or use the default graph
when then predictor is first initialized.
You need to set either `model`, or `inputs_desc` plus `tower_func`.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
self.model = model
assert_type(self.model, ModelDescBase)
if model is not None:
assert_type(model, ModelDescBase)
assert inputs_desc is None and tower_func is None
self.inputs_desc = model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
else:
assert inputs_desc is not None and tower_func is not None
self.inputs_desc = inputs_desc
self.tower_func = TowerFuncWrapper(tower_func, inputs_desc)
if session_init is None:
session_init = JustCurrentSession()
......@@ -48,7 +63,7 @@ class PredictConfig(object):
assert_type(self.session_init, SessionInit)
if session_creator is None:
self.session_creator = NewSessionCreator(config=get_default_sess_config())
self.session_creator = tf.train.ChiefSessionCreator(config=get_default_sess_config())
else:
self.session_creator = session_creator
......
......@@ -5,7 +5,7 @@
import tensorflow as tf
from ..utils import logger
from ..graph_builder.predictor_factory import PredictorFactory
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import PlaceholderInput
from .base import OnlinePredictor
......@@ -28,13 +28,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self.return_input = config.return_input
with self.graph.as_default():
handles = []
factory = PredictorFactory(config.model, towers)
input = PlaceholderInput()
input.setup(config.inputs_desc)
for idx, t in enumerate(towers):
tower_name = 'tower' + str(t)
device = '/gpu:' + str(t)
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
handles.append(factory.build(tower_name, device))
builder = SimplePredictBuilder(ns_name=tower_name, device=t)
builder.build(input, config.tower_func)
handles.append(config.tower_func.towers[-1])
self.sess = config.session_creator.create_session()
config.session_init.init(self.sess)
......@@ -87,15 +91,15 @@ class DataParallelOfflinePredictor(OnlinePredictor):
input_tensors = []
output_tensors = []
factory = PredictorFactory(config.model, towers)
for idx, t in enumerate(towers):
tower_name = 'tower' + str(t)
device = '/gpu:' + str(t)
input = PlaceholderInput(tower_name + '/')
input.setup(config.model.get_inputs_desc())
input.setup(config.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
h = factory.build(tower_name, device, )
builder = SimplePredictBuilder(ns_name=tower_name, device=t)
builder.build(input, config.tower_func)
h = config.tower_func.towers[-1]
input_tensors.extend(h.get_tensors(config.input_names))
output_tensors.extend(h.get_tensors(config.output_names))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment