Commit e1f9cc09 authored by Yuxin Wu's avatar Yuxin Wu

let DataParallelInference use PredictorFactory

parent 3951aaf7
...@@ -7,6 +7,7 @@ import tensorflow as tf ...@@ -7,6 +7,7 @@ import tensorflow as tf
from tensorflow.python.training.monitored_session \ from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession import _HookedSession as HookedSession
import itertools
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import tqdm import tqdm
import six import six
...@@ -15,13 +16,11 @@ from six.moves import range ...@@ -15,13 +16,11 @@ from six.moves import range
from ..utils import logger, get_tqdm_kwargs from ..utils import logger, get_tqdm_kwargs
from ..utils.develop import deprecated from ..utils.develop import deprecated
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
from ..graph_builder.input_source_base import InputSource from ..graph_builder.input_source_base import InputSource
from ..graph_builder.input_source import ( from ..graph_builder.input_source import (
FeedInput, DataParallelFeedInput, FeedfreeInput, TensorInput) FeedInput, DataParallelFeedInput, FeedfreeInput, TensorInput)
from ..predict import PredictorTowerBuilder
from .base import Callback from .base import Callback
from .inference import Inferencer from .inference import Inferencer
...@@ -88,11 +87,12 @@ class InferenceRunnerBase(Callback): ...@@ -88,11 +87,12 @@ class InferenceRunnerBase(Callback):
self._extra_hooks = extra_hooks self._extra_hooks = extra_hooks
def _setup_graph(self): def _setup_graph(self):
self._input_source.setup(self.trainer.model.get_inputs_desc())
# Use predict_tower in train config. either gpuid or -1 # Use predict_tower in train config. either gpuid or -1
tower_id = self.trainer.config.predict_tower[0] tower_id = self.trainer.config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0' device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
tower_name = TowerContext.get_predict_tower_name(tower_id, prefix=self._prefix) tower_name = TowerContext.get_predict_tower_name(tower_id, prefix=self._prefix)
self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(tower_name, device, self._input_source) self._tower_handle = self.trainer.predictor_factory.build(tower_name, device, self._input_source)
...@@ -172,18 +172,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -172,18 +172,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self._gpus = gpus self._gpus = gpus
def _setup_graph(self): def _setup_graph(self):
model = self.trainer.model self._input_source.setup(self.trainer.model.get_inputs_desc())
self._input_source.setup(model.get_inputs_desc()) self._handles = []
# build graph
def build_tower(k):
# inputs (placeholders) for this tower only
model.build_graph(self._input_source)
builder = PredictorTowerBuilder(build_tower, prefix=self._prefix)
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for t in self._gpus: for t in self._gpus:
builder.build(t) tower_name = TowerContext.get_predict_tower_name(t, prefix=self._prefix)
device = '/gpu:{}'.format(t)
self._handles.append(
self.trainer.predictor_factory.build(
tower_name, device, self._input_source))
# setup feeds and hooks # setup feeds and hooks
self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs] self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
...@@ -191,15 +188,12 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -191,15 +188,12 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
cbs = self._input_source.get_callbacks() cbs = self._input_source.get_callbacks()
self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs]) self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs])
def _duplicate_names_across_towers(self, names):
ret = []
for t in self._gpus:
ret.extend([TowerContext.get_predict_tower_name(t, self._prefix) +
'/' + n for n in names])
return ret
class InferencerToHookDataParallel(InferencerToHook): class InferencerToHookDataParallel(InferencerToHook):
def __init__(self, inf, fetches, size): def __init__(self, inf, fetches, size):
"""
Args:
size(int): number of tensors to fetch per tower
"""
super(DataParallelInferenceRunner.InferencerToHookDataParallel, self).__init__(inf, fetches) super(DataParallelInferenceRunner.InferencerToHookDataParallel, self).__init__(inf, fetches)
assert len(self._fetches) % size == 0 assert len(self._fetches) % size == 0
self._sz = size self._sz = size
...@@ -213,16 +207,12 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -213,16 +207,12 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def _build_hook_parallel(self, inf): def _build_hook_parallel(self, inf):
out_names = inf.get_output_tensors() out_names = inf.get_output_tensors()
sz = len(out_names) sz = len(out_names)
out_names = self._duplicate_names_across_towers(out_names) fetches = list(itertools.chain(*[t.get_tensors(out_names) for t in self._handles]))
fetches = get_tensors_by_names(out_names) return self.InferencerToHookDataParallel(inf, fetches, sz)
return DataParallelInferenceRunner.InferencerToHookDataParallel(
inf, fetches, sz)
def _build_hook(self, inf): def _build_hook(self, inf):
out_names = inf.get_output_tensors() out_names = inf.get_output_tensors()
names = [TowerContext.get_predict_tower_name( fetches = self._handles[0].get_tensors(out_names)
self._gpus[0], self._prefix) + '/' + n for n in out_names]
fetches = get_tensors_by_names(names)
return InferencerToHook(inf, fetches) return InferencerToHook(inf, fetches)
def _before_train(self): def _before_train(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment