Commit 74c80d57 authored by Yuxin Wu's avatar Yuxin Wu

refactor InferenceRunner. (fix #139)

parent b9a15df7
...@@ -7,7 +7,7 @@ Neural Network Toolbox on TensorFlow. ...@@ -7,7 +7,7 @@ Neural Network Toolbox on TensorFlow.
See some [examples](examples) to learn about the framework: See some [examples](examples) to learn about the framework:
### Vision: ### Vision:
+ [Train ResNet on ImageNet / Cifar10 / SVHN](examples/ResNet) + [Multi-GPU training of ResNet on ImageNet](examples/ResNet)
+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN. + [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
+ [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net) + [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED) + [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED)
...@@ -43,7 +43,7 @@ It's Yet Another TF wrapper, but different in: ...@@ -43,7 +43,7 @@ It's Yet Another TF wrapper, but different in:
+ Data-parallel distributed training is off-the-shelf to use. It is as slow as Google's official benchmark. + Data-parallel distributed training is off-the-shelf to use. It is as slow as Google's official benchmark.
3. Focus on large datasets. 3. Focus on large datasets.
+ It's painful to read/preprocess data through TF. Use __DataFlow__ to load large datasets (e.g. ImageNet) in __pure Python__ with multi-process prefetch. + It's painful to read/preprocess data through TF. Use __DataFlow__ to load large datasets (e.g. ImageNet) in __pure Python__ with autoparallelization.
+ DataFlow has a unified interface, so you can compose and reuse them to perform complex preprocessing. + DataFlow has a unified interface, so you can compose and reuse them to perform complex preprocessing.
4. Interface of extensible __Callbacks__. 4. Interface of extensible __Callbacks__.
......
...@@ -8,9 +8,8 @@ from tensorflow.python.training.monitored_session \ ...@@ -8,9 +8,8 @@ from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession import _HookedSession as HookedSession
import itertools import itertools
from abc import ABCMeta, abstractmethod from contextlib import contextmanager
import tqdm import tqdm
import six
from six.moves import range from six.moves import range
from ..utils import logger from ..utils import logger
...@@ -42,16 +41,29 @@ class InferencerToHook(tf.train.SessionRunHook): ...@@ -42,16 +41,29 @@ class InferencerToHook(tf.train.SessionRunHook):
self._inf.on_fetches(run_values.results) self._inf.on_fetches(run_values.results)
@six.add_metaclass(ABCMeta) @contextmanager
def _inference_context():
msg = "You might need to check your input implementation."
try:
yield
except (StopIteration,
tf.errors.CancelledError,
tf.errors.OutOfRangeError):
logger.error(
"[InferenceRunner] input stopped before reaching its size()! " + msg)
raise
class InferenceRunnerBase(Callback): class InferenceRunnerBase(Callback):
""" Base methods for inference runner""" """ Base class for inference runner.
def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None, prefix=None): Please note that InferenceRunner will use `input.size()` to determine
how much iterations to run, so you want it to be accurate.
"""
def __init__(self, input, infs, extra_hooks=None):
""" """
Args: Args:
input (InputSource): the input to use. Must have ``size()``. input (InputSource): the input to use. Must have ``size()``.
infs (list[Inferencer]): list of :class:`Inferencer` to run. infs (list[Inferencer]): list of :class:`Inferencer` to run.
tower_name(str): name scope to build the tower. Must be set
differently if more than one :class:`InferenceRunner` are used.
extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation. extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
""" """
self._input_source = input self._input_source = input
...@@ -66,14 +78,46 @@ class InferenceRunnerBase(Callback): ...@@ -66,14 +78,46 @@ class InferenceRunnerBase(Callback):
self._size = input.size() self._size = input.size()
except NotImplementedError: except NotImplementedError:
raise ValueError("Input used in InferenceRunner must have a size!") raise ValueError("Input used in InferenceRunner must have a size!")
self._tower_name = tower_name
if prefix is not None:
self._tower_name = 'InferenceTower' + prefix
if extra_hooks is None: if extra_hooks is None:
extra_hooks = [] extra_hooks = []
self._extra_hooks = extra_hooks self._extra_hooks = extra_hooks
def _before_train(self):
self._hooks.extend(self._extra_hooks)
self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
self._input_callbacks.before_train()
def _after_train(self):
self._input_callbacks.after_train()
class InferenceRunner(InferenceRunnerBase):
"""
A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
"""
def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None):
"""
Args:
input (InputSource or DataFlow): The :class:`InputSource` to run
inference on. If given a DataFlow, will use :class:`FeedInput`.
infs (list): a list of :class:`Inferencer` instances.
tower_name (str): the name scope of the tower to build. Need to set a
different one if multiple InferenceRunner are used.
"""
if isinstance(input, DataFlow):
input = FeedInput(input, infinite=False)
assert isinstance(input, InputSource), input
self._tower_name = tower_name
super(InferenceRunner, self).__init__(
input, infs, extra_hooks=extra_hooks)
def _build_hook(self, inf):
out_names = inf.get_fetches()
fetches = self._tower_handle.get_tensors(out_names)
return InferencerToHook(inf, fetches)
def _setup_graph(self): def _setup_graph(self):
# Use predict_tower in train config. either gpuid or -1 # Use predict_tower in train config. either gpuid or -1
tower_id = self.trainer.config.predict_tower[0] tower_id = self.trainer.config.predict_tower[0]
...@@ -95,61 +139,19 @@ class InferenceRunnerBase(Callback): ...@@ -95,61 +139,19 @@ class InferenceRunnerBase(Callback):
inf.setup_graph(self.trainer) inf.setup_graph(self.trainer)
self._input_callbacks.setup_graph(self.trainer) self._input_callbacks.setup_graph(self.trainer)
def _before_train(self):
self._hooks.extend(self._extra_hooks)
self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
self._input_callbacks.before_train()
def _after_train(self):
self._input_callbacks.after_train()
@abstractmethod
def _build_hook(self, inf):
pass
def _trigger(self): def _trigger(self):
for inf in self.infs: for inf in self.infs:
inf.before_epoch() inf.before_epoch()
# iterate over the data, and run the hooked session # iterate over the data, and run the hooked session
self._input_source.reset_state() self._input_source.reset_state()
msg = "You might need to check your input implementation." with _inference_context():
try:
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()): for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
self._hooked_sess.run(fetches=[]) self._hooked_sess.run(fetches=[])
except (StopIteration, tf.errors.CancelledError,
tf.errors.OutOfRangeError):
logger.error(
"[InferenceRunner] input stopped before reaching its size()! " + msg)
raise
for inf in self.infs: for inf in self.infs:
inf.trigger_epoch() inf.trigger_epoch()
class InferenceRunner(InferenceRunnerBase):
"""
A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
"""
def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None):
"""
Args:
input (InputSource or DataFlow): The :class:`InputSource` to run
inference on. If given a DataFlow, will use :class:`FeedInput`.
infs (list): a list of :class:`Inferencer` instances.
"""
if isinstance(input, DataFlow):
input = FeedInput(input, infinite=False)
assert isinstance(input, InputSource), input
super(InferenceRunner, self).__init__(
input, infs, tower_name=tower_name, extra_hooks=extra_hooks)
def _build_hook(self, inf):
out_names = inf.get_fetches()
fetches = self._tower_handle.get_tensors(out_names)
return InferencerToHook(inf, fetches)
@deprecated("Just use InferenceRunner since it now accepts TensorInput!", "2017-11-11") @deprecated("Just use InferenceRunner since it now accepts TensorInput!", "2017-11-11")
def FeedfreeInferenceRunner(*args, **kwargs): def FeedfreeInferenceRunner(*args, **kwargs):
return InferenceRunner(*args, **kwargs) return InferenceRunner(*args, **kwargs)
...@@ -157,24 +159,26 @@ def FeedfreeInferenceRunner(*args, **kwargs): ...@@ -157,24 +159,26 @@ def FeedfreeInferenceRunner(*args, **kwargs):
class DataParallelInferenceRunner(InferenceRunnerBase): class DataParallelInferenceRunner(InferenceRunnerBase):
""" """
Inference by feeding datapoints in a data-parallel way to multiple GPUs. Inference with data-parallel support on multiple GPUs.
It will build one predict tower on each GPU, and run prediction
Doesn't support remapped InputSource for now. with a larger batch.
""" """
def __init__(self, input, infs, gpus): def __init__(self, input, infs, gpus):
""" """
Args: Args:
input (DataParallelFeedInput or DataFlow) input (DataFlow or QueueInput)
gpus (list[int]): list of GPU id gpus (list[int]): list of GPU id
""" """
self._tower_names = ['InferenceTower{}'.format(k) for k in range(len(gpus))] self._tower_names = ['InferenceTower{}'.format(k) for k in range(len(gpus))]
if isinstance(input, DataFlow): if isinstance(input, DataFlow):
input = QueueInput(input) input = QueueInput(input)
assert isinstance(input, QueueInput), input
super(DataParallelInferenceRunner, self).__init__(input, infs) super(DataParallelInferenceRunner, self).__init__(input, infs)
self._gpus = gpus self._gpus = gpus
def _setup_graph(self): def _setup_graph(self):
cbs = self._input_source.setup(self.trainer.model.get_inputs_desc()) cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
# build each predict tower
self._handles = [] self._handles = []
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, t in enumerate(self._gpus): for idx, t in enumerate(self._gpus):
...@@ -184,11 +188,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -184,11 +188,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self.trainer.predictor_factory.build( self.trainer.predictor_factory.build(
tower_name, device, self._input_source)) tower_name, device, self._input_source))
# setup callbacksand hooks # setup callbacks and hooks
self._input_callbacks = Callbacks(cbs) self._input_callbacks = Callbacks(cbs)
self._hooks = [self._build_hook(inf) for inf in self.infs]
self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs] # InputSource might have hooks which break us.
self._hooks_parallel.extend(self._input_callbacks.get_hooks()) # e.g. hooks from StagingInputWrapper will force the consumption
# of nr_tower datapoints in every run.
input_hooks = self._input_callbacks.get_hooks()
self._hooks = [self._build_hook(inf) for inf in self.infs] + input_hooks
self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs] + input_hooks
for inf in self.infs: for inf in self.infs:
inf.setup_graph(self.trainer) inf.setup_graph(self.trainer)
...@@ -232,15 +240,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -232,15 +240,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self._input_source.reset_state() self._input_source.reset_state()
total = self._size total = self._size
nr_tower = len(self._gpus) nr_tower = len(self._gpus)
with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar: with _inference_context():
while total >= nr_tower: with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
self._parallel_hooked_sess.run(fetches=[]) while total >= nr_tower:
pbar.update(nr_tower) self._parallel_hooked_sess.run(fetches=[])
total -= nr_tower pbar.update(nr_tower)
# take care of the rest total -= nr_tower
while total > 0: # take care of the rest
self._hooked_sess.run(fetches=[]) for _ in range(total):
pbar.update(1) self._hooked_sess.run(fetches=[])
total -= 1 pbar.update(1)
for inf in self.infs: for inf in self.infs:
inf.trigger_epoch() inf.trigger_epoch()
...@@ -143,7 +143,7 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -143,7 +143,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
nr_proc (int): number of processes to use. nr_proc (int): number of processes to use.
hwm (int): the zmq "high-water mark" for both sender and receiver. hwm (int): the zmq "high-water mark" for both sender and receiver.
""" """
assert os.name != 'nt', "PrefetchDataZMQ doesn't support windows! Consider PrefetchData instead." assert os.name != 'nt', "PrefetchDataZMQ doesn't support windows! Consider PrefetchData instead."
super(PrefetchDataZMQ, self).__init__(ds) super(PrefetchDataZMQ, self).__init__(ds)
try: try:
self._size = ds.size() self._size = ds.size()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment