Commit 7699fd9b authored by Yuxin Wu's avatar Yuxin Wu

let multi-gpu OfflinePredictor use PredictorFactory

parent e1f9cc09
......@@ -21,13 +21,31 @@ from ..utils import logger
from ..utils.concurrency import ShareSessionThread
from ..callbacks.base import Callback
__all__ = ['FeedInput', 'DataParallelFeedInput',
__all__ = ['PlaceholderInput', 'FeedInput', 'DataParallelFeedInput',
'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
'ZMQInput', 'DummyConstantInput', 'TensorInput',
'StagingInputWrapper']
class PlaceholderInput(InputSource):
    """
    An :class:`InputSource` that does nothing but create one placeholder
    per input description, to be fed externally.
    """
    def __init__(self, prefix=''):
        """
        Args:
            prefix(str): an optional prefix to add to the placeholder.
        """
        self._prefix = prefix

    def _setup(self, inputs):
        # Build one placeholder per InputDesc, carrying the optional name prefix.
        placehdrs = []
        for inp in inputs:
            placehdrs.append(inp.build_placeholder(prefix=self._prefix))
        self._all_placehdrs = placehdrs

    def _get_input_tensors(self):
        return self._all_placehdrs
class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """
......@@ -60,16 +78,16 @@ class FeedInput(InputSource):
return self.ds.size()
def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._all_placehdrs = [v.build_placeholder(prefix=self._prefix) for v in inputs]
self._cb = self._FeedCallback(self._repeat_ds, self._all_placehdrs)
self.reset_state()
def _reset_state(self):
self._cb._reset()
def _get_input_tensors(self):
return self._all_placehdrs
def _reset_state(self):
self._cb._reset()
def _get_callbacks(self):
return [self._cb]
......
......@@ -54,7 +54,6 @@ class InputSource(object):
# TODO
self._reset_state()
@abstractmethod
def _reset_state(self):
pass
......
......@@ -7,7 +7,6 @@ from ..utils import logger
from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..tfutils.collection import freeze_collection
from ..predict import OnlinePredictor
from ..utils.naming import TOWER_FREEZE_KEYS
__all__ = ['PredictorFactory']
......@@ -41,7 +40,7 @@ class PredictorTowerHandle(object):
class PredictorFactory(object):
""" Make predictors from :class:`ModelDesc`."""
def __init__(self, model, towers, vs_name):
def __init__(self, model, towers, vs_name=''):
"""
Args:
model (ModelDesc):
......@@ -97,4 +96,5 @@ class PredictorFactory(object):
in_tensors = handle.get_tensors(input_names)
out_tensors = handle.get_tensors(output_names)
from ..predict import OnlinePredictor # noqa TODO
return OnlinePredictor(in_tensors, out_tensors)
......@@ -3,9 +3,12 @@
# File: multigpu.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext
from .base import OnlinePredictor, build_prediction_graph, PredictorTowerBuilder
from ..tfutils import TowerContext
from ..graph_builder.predictor_factory import PredictorFactory
from ..graph_builder.input_source import PlaceholderInput
from .base import OnlinePredictor
__all__ = ['MultiTowerOfflinePredictor',
'DataParallelOfflinePredictor']
......@@ -23,20 +26,24 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
assert len(towers) > 0
self.graph = config._maybe_create_graph()
self.predictors = []
self.return_input = config.return_input
with self.graph.as_default():
placeholder_names = set([k.name for k in config.model.get_inputs_desc()])
handles = []
factory = PredictorFactory(config.model, towers)
for idx, t in enumerate(towers):
tower_name = TowerContext.get_predict_tower_name(t)
device = '/gpu:' + str(t)
def fn(_):
config.model.build_graph(config.model.get_reused_placehdrs())
build_prediction_graph(fn, towers)
# TODO smarter TowerContext?
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
handles.append(factory.build(tower_name, device))
self.sess = config.session_creator.create_session()
config.session_init.init(self.sess)
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
for k in towers:
input_tensors = get_tensor_fn(placeholder_names, config.input_names, k)
output_tensors = get_tensor_fn(placeholder_names, config.output_names, k)
for h in handles:
input_tensors = h.get_tensors(config.input_names)
output_tensors = h.get_tensors(config.output_names)
self.predictors.append(OnlinePredictor(
input_tensors, output_tensors, config.return_input, self.sess))
......@@ -79,23 +86,20 @@ class DataParallelOfflinePredictor(OnlinePredictor):
"""
self.graph = config._maybe_create_graph()
with self.graph.as_default():
input_names = []
input_tensors = []
output_tensors = []
def build_tower(k):
towername = TowerContext.get_predict_tower_name(k)
# inputs (placeholders) for this tower only
input_tensors = config.model.build_placeholders(prefix=towername + '/')
config.model.build_graph(input_tensors)
input_names.extend([t.name for t in input_tensors])
output_tensors.extend(get_tensors_by_names(
[towername + '/' + n
for n in config.output_names]))
build_prediction_graph(build_tower, towers)
input_tensors = get_tensors_by_names(input_names)
factory = PredictorFactory(config.model, towers)
for idx, t in enumerate(towers):
tower_name = TowerContext.get_predict_tower_name(t)
device = '/gpu:' + str(t)
input = PlaceholderInput(tower_name + '/')
input.setup(config.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
h = factory.build(tower_name, device, )
input_tensors.extend(h.get_tensors(config.input_names))
output_tensors.extend(h.get_tensors(config.output_names))
sess = config.session_creator.create_session()
config.session_init.init(sess)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment