Commit e61946b2 authored by Yuxin Wu

Remove SimplePredictBuilder

parent 25b31f68
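
The change is mechanical across call sites: every use of SimplePredictBuilder is replaced by an explicit device placement plus a prediction-mode tower context. A minimal sketch of that pattern, assuming a TowerFuncWrapper `tower_func` and an already-set-up InputSource `input` (the helper name itself is illustrative, not a tensorpack API):

import tensorflow as tf
from tensorpack.tfutils.tower import PredictTowerContext


def build_predict_tower(tower_func, input, tower_name, device='/gpu:0', vs_name=''):
    # Reuse the existing variables, pin the tower to one device, and enter a
    # prediction TowerContext directly instead of going through SimplePredictBuilder.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
            tf.device(device), \
            PredictTowerContext(tower_name, vs_name=vs_name):
        tower_func(*input.get_input_tensors())
    return tower_func.towers[-1]  # handle of the tower that was just built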
@@ -353,6 +353,34 @@ def process_signature(app, what, name, obj, options, signature,
     # signature: arg list
     return signature, return_annotation
 
+
+_DEPRECATED_NAMES = set([
+    # deprecated stuff:
+    'TryResumeTraining',
+    'QueueInputTrainer',
+    'SimplePredictBuilder',
+
+    # renamed stuff:
+    'DumpTensor',
+    'DumpParamAsImage',
+    'StagingInputWrapper',
+    'PeriodicRunHooks',
+    'get_nr_gpu',
+
+    # deprecated or renamed symbolic code
+    'ImageSample',
+    'Deconv2D',
+    'get_scalar_var', 'psnr',
+    'prediction_incorrect', 'huber_loss',
+
+    # internal only
+    'apply_default_prefetch',
+    'average_grads',
+    'aggregate_grads',
+    'allreduce_grads',
+    'PrefetchOnGPUs',
+])
+
+
 def autodoc_skip_member(app, what, name, obj, skip, options):
     # we hide something deliberately
     if getattr(obj, '__HIDE_SPHINX_DOC__', False):
@@ -363,31 +391,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
         # https://github.com/sphinx-doc/sphinx/issues/4258
         return False
     # Hide some names that are deprecated or not intended to be used
-    if name in [
-        # deprecated stuff:
-        'TryResumeTraining',
-        'QueueInputTrainer',
-
-        # renamed stuff:
-        'DumpTensor',
-        'DumpParamAsImage',
-        'StagingInputWrapper',
-        'PeriodicRunHooks',
-        'get_nr_gpu',
-
-        # deprecated or renamed symbolic code
-        'ImageSample',
-        'Deconv2D',
-        'get_scalar_var', 'psnr',
-        'prediction_incorrect', 'huber_loss',
-
-        # internal only
-        'apply_default_prefetch',
-        'average_grads',
-        'aggregate_grads',
-        'allreduce_grads',
-        'PrefetchOnGPUs',
-    ]:
+    if name in _DEPRECATED_NAMES:
         return True
     if name in ['get_data', 'size', 'reset_state']:
         # skip these methods with empty docstring
...
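
For context, a skip callback like autodoc_skip_member only takes effect once it is connected to Sphinx's autodoc-skip-member event; a generic sketch of that wiring (not copied from this repository's conf.py):

def setup(app):
    # Let autodoc consult the callback (and thus _DEPRECATED_NAMES) for every member.
    app.connect('autodoc-skip-member', autodoc_skip_member)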
@@ -15,10 +15,10 @@ from six.moves import range
 
 from ..utils import logger
 from ..utils.utils import get_tqdm_kwargs
 from ..dataflow.base import DataFlow
+from ..tfutils.tower import PredictTowerContext
 from ..input_source import (
     InputSource, FeedInput, QueueInput, StagingInput)
-from ..graph_builder.predict import SimplePredictBuilder
 
 from .base import Callback
 from .group import Callbacks
@@ -28,6 +28,10 @@ __all__ = ['InferenceRunnerBase', 'InferenceRunner',
            'DataParallelInferenceRunner']
 
 
+def _device_from_int(dev):
+    return '/gpu:{}'.format(dev) if dev >= 0 else '/cpu:0'
+
+
 class InferencerToHook(tf.train.SessionRunHook):
     def __init__(self, inf, fetches):
         self._inf = inf
@@ -94,9 +98,9 @@ class InferenceRunnerBase(Callback):
         self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
         self._input_callbacks.before_train()
         if self._size > 0:
-            logger.info("InferenceRunner will eval {} iterations".format(self._size))
+            logger.info("[InferenceRunner] Will eval {} iterations".format(self._size))
         else:
-            logger.warn("InferenceRunner got an input with unknown size! It will iterate until OutOfRangeError!")
+            logger.warn("[InferenceRunner] Got an InputSource with unknown size! Will iterate until OutOfRangeError!")
 
     def _after_train(self):
         self._input_callbacks.after_train()
@@ -122,7 +126,7 @@ class InferenceRunner(InferenceRunnerBase):
         assert isinstance(input, InputSource), input
         assert not isinstance(input, StagingInput), input
         self._tower_name = tower_name
-        self._device = device
+        self._device = _device_from_int(device)
         super(InferenceRunner, self).__init__(input, infs)
 
     def _build_hook(self, inf):
@@ -131,16 +135,17 @@ class InferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, fetches)
 
     def _setup_graph(self):
-        device = self._device
         assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
-        input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
-
-        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            SimplePredictBuilder(
-                ns_name=self._tower_name,
-                vs_name=self.trainer._main_tower_vs_name, device=device).build(
-                    self._input_source, self.trainer.tower_func)
-        self._tower_handle = self.trainer.tower_func.towers[-1]
+        tower_func = self.trainer.tower_func
+        input_callbacks = self._input_source.setup(tower_func.inputs_desc)
+
+        logger.info("[InferenceRunner] Building tower '{}' on device {} ...".format(self._tower_name, self._device))
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
+                tf.device(self._device), \
+                PredictTowerContext(
+                    self._tower_name, vs_name=self.trainer._main_tower_vs_name):
+            tower_func(*self._input_source.get_input_tensors())
+        self._tower_handle = tower_func.towers[-1]
 
         for h in [self._build_hook(inf) for inf in self.infs]:
             self.register_hook(h)
@@ -178,7 +183,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
     It will run the remainder (when the total size of input is not a multiple of #GPU)
     sequentially.
     """
-    def __init__(self, input, infs, gpus):
+    def __init__(self, input, infs, gpus, tower_name='InferenceTower'):
         """
         Args:
             input (DataFlow or QueueInput)
@@ -186,13 +191,14 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         """
        if isinstance(gpus, int):
            gpus = list(range(gpus))
-        self._tower_names = ['InferenceTower{}'.format(k) for k in range(len(gpus))]
+        self._devices = [_device_from_int(k) for k in gpus]
+        self._tower_names = ['{}{}'.format(tower_name, k) for k in range(len(gpus))]
         if isinstance(input, DataFlow):
             input = QueueInput(input)
         assert isinstance(input, QueueInput), input
         super(DataParallelInferenceRunner, self).__init__(input, infs)
         assert self._size > 0, "Input for DataParallelInferenceRunner must have a size!"
-        self._gpus = gpus
 
         self._hooks = []
         self._hooks_parallel = []
@@ -201,15 +207,14 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         self._handles = []
 
         assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
-        input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
+        tower_func = self.trainer.tower_func
+        input_callbacks = self._input_source.setup(tower_func.inputs_desc)
         with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            for idx, t in enumerate(self._gpus):
-                tower_name = self._tower_names[idx]
-                SimplePredictBuilder(
-                    ns_name=tower_name,
-                    vs_name=self.trainer._main_tower_vs_name, device=t).build(
-                        self._input_source, self.trainer.tower_func)
-                self._handles.append(self.trainer.tower_func.towers[-1])
+            for idx, dev in enumerate(self._devices):
+                with tf.device(dev), PredictTowerContext(
+                        self._tower_names[idx], vs_name=self.trainer._main_tower_vs_name):
+                    tower_func(*self._input_source.get_input_tensors())
+                    self._handles.append(tower_func.towers[-1])
 
         # setup callbacks and hooks
         self._input_callbacks = Callbacks(input_callbacks)
@@ -267,7 +272,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             inf.before_epoch()
 
         total = self._size
-        nr_tower = len(self._gpus)
+        nr_tower = len(self._devices)
         self._input_source.reset_state()
         with _inference_context():
             with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
...
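
Two small conventions introduced in this file are easy to check in isolation: the integer-to-device mapping and the per-GPU tower names. The snippet below restates logic from the hunks above rather than adding new behavior:

def _device_from_int(dev):
    # Nonnegative index selects that GPU; anything else (e.g. -1) means CPU.
    return '/gpu:{}'.format(dev) if dev >= 0 else '/cpu:0'


gpus = [0, 1]
tower_name = 'InferenceTower'  # default of the new tower_name argument
devices = [_device_from_int(k) for k in gpus]
tower_names = ['{}{}'.format(tower_name, k) for k in range(len(gpus))]
assert devices == ['/gpu:0', '/gpu:1']
assert tower_names == ['InferenceTower0', 'InferenceTower1']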
@@ -4,6 +4,7 @@
 import tensorflow as tf
 
 from ..utils import logger
+from ..utils.develop import deprecated
 from ..tfutils.tower import PredictTowerContext
 from .training import GraphBuilder
@@ -14,6 +15,7 @@ class SimplePredictBuilder(GraphBuilder):
     """
     Single-tower predictor.
     """
+    @deprecated("Please use TowerContext to build it by yourself!", "2018-12-31")
     def __init__(self, ns_name='', vs_name='', device=0):
         """
         Args:
...
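
A minimal sketch of how the deprecation marker above is applied; the (message, end-date) argument order follows the hunk, while the class below is a placeholder rather than the real builder:

from tensorpack.utils.develop import deprecated


class SomeOldBuilder(object):
    @deprecated("Please use TowerContext to build it by yourself!", "2018-12-31")
    def __init__(self, ns_name='', vs_name='', device=0):
        # Construction still works, but logs a deprecation warning pointing
        # users at the TowerContext-based replacement.
        self.ns_name, self.vs_name, self.device = ns_name, vs_name, device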
@@ -92,6 +92,7 @@ class InputSource(object):
         Returns:
             list[Callback]: extra callbacks needed by this InputSource.
+                callbacks of InputSource cannot use any `trigger*()` method.
         """
         self._setup(inputs_desc)
         self._setup_done = True
...
@@ -4,9 +4,9 @@
 import tensorflow as tf
 
 from ..utils import logger
-from ..graph_builder.predict import SimplePredictBuilder
 from ..graph_builder.model_desc import InputDesc
 from ..input_source import PlaceholderInput
+from ..tfutils.tower import PredictTowerContext
 from .base import OnlinePredictor
 
 __all__ = ['MultiTowerOfflinePredictor',
@@ -14,7 +14,9 @@ __all__ = ['MultiTowerOfflinePredictor',
 
 class MultiTowerOfflinePredictor(OnlinePredictor):
-    """ A multi-tower multi-GPU predictor. """
+    """ A multi-tower multi-GPU predictor.
+        It builds one predictor for each tower.
+    """
 
     def __init__(self, config, towers):
         """
@@ -35,9 +37,10 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
             for idx, t in enumerate(towers):
                 tower_name = 'tower' + str(t)
-                with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
-                    builder = SimplePredictBuilder(ns_name=tower_name, device=t)
-                    builder.build(input, config.tower_func)
+                with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0), \
+                        tf.device('/gpu:{}'.format(t)), \
+                        PredictTowerContext(tower_name):
+                    config.tower_func(*input.get_input_tensors())
                     handles.append(config.tower_func.towers[-1])
 
         self.sess = config.session_creator.create_session()
@@ -73,7 +76,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
 
 class DataParallelOfflinePredictor(OnlinePredictor):
     """
-    A data-parallel predictor.
+    A data-parallel predictor. It builds one predictor that utilizes all GPUs.
+
     Note that it doesn't split/concat inputs/outputs automatically.
     Instead, its inputs are:
     ``[input[0] in tower[0], input[1] in tower[0], ..., input[0] in tower[1], input[1] in tower[1], ...]``
@@ -99,9 +103,10 @@ class DataParallelOfflinePredictor(OnlinePredictor):
             input = PlaceholderInput()
             input.setup(inputs_desc)
-            with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0):
-                builder = SimplePredictBuilder(ns_name=tower_name, device=t)
-                builder.build(input, config.tower_func)
+            with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0), \
+                    tf.device('/gpu:{}'.format(t)), \
+                    PredictTowerContext(tower_name):
+                config.tower_func(*input.get_input_tensors())
                 h = config.tower_func.towers[-1]
                 input_tensors.extend(h.get_tensors(config.input_names))
                 output_tensors.extend(h.get_tensors(config.output_names))
...
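
To make the input-ordering note in the DataParallelOfflinePredictor docstring concrete: inputs are flattened tower by tower, not input by input. The helper below is purely illustrative and not part of tensorpack:

def flatten_tower_inputs(per_tower_inputs):
    # [[img0, lbl0], [img1, lbl1]] -> [img0, lbl0, img1, lbl1]
    return [x for tower in per_tower_inputs for x in tower]


assert flatten_tower_inputs([['img0', 'lbl0'], ['img1', 'lbl1']]) == \
    ['img0', 'lbl0', 'img1', 'lbl1']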
@@ -6,11 +6,10 @@ import six
 from abc import abstractmethod, ABCMeta
 
 from ..utils.argtools import call_only_once, memoized
-from ..graph_builder.predict import SimplePredictBuilder
 from ..input_source import PlaceholderInput
 from ..predict.base import OnlinePredictor
-from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
+from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context, PredictTowerContext
 from ..tfutils.gradproc import FilterNoneGrad
 from .base import Trainer
@@ -94,6 +93,7 @@ class TowerTrainer(Trainer):
         """
         assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
         tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
+        device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
 
         try:
             tower = self.tower_func.towers[tower_name]
@@ -105,10 +105,10 @@ class TowerTrainer(Trainer):
             input = PlaceholderInput()
             input.setup(self.inputs_desc)
-            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-                SimplePredictBuilder(
-                    ns_name=tower_name, vs_name=self._main_tower_vs_name,
-                    device=device).build(input, self.tower_func)
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
+                    tf.device(device), PredictTowerContext(
+                        tower_name, vs_name=self._main_tower_vs_name):
+                self.tower_func(*input.get_input_tensors())
             tower = self.tower_func.towers[tower_name]
 
         input_tensors = tower.get_tensors(input_names)
         output_tensors = tower.get_tensors(output_names)
...
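
A hedged usage sketch of TowerTrainer.get_predictor after this change; `trainer` and the tensor names are placeholders, while the tower-name and device mapping follows the hunk above:

# `trainer` is assumed to be a TowerTrainer whose tower_func has been set;
# 'input' and 'prob' are placeholder tensor names.
pred_gpu = trainer.get_predictor(['input'], ['prob'], device=0)    # builds 'tower-pred-0' on /gpu:0
pred_cpu = trainer.get_predictor(['input'], ['prob'], device=-1)   # builds 'tower-pred-cpu' on /cpu:0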