Commit 74c80d57 authored by Yuxin Wu

refactor InferenceRunner. (fix #139)

parent b9a15df7
......@@ -7,7 +7,7 @@ Neural Network Toolbox on TensorFlow.
See some [examples](examples) to learn about the framework:
### Vision:
+ [Train ResNet on ImageNet / Cifar10 / SVHN](examples/ResNet)
+ [Multi-GPU training of ResNet on ImageNet](examples/ResNet)
+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
+ [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED)
......@@ -43,7 +43,7 @@ It's Yet Another TF wrapper, but different in:
+ Data-parallel distributed training is off-the-shelf to use. It is as slow as Google's official benchmark.
3. Focus on large datasets.
+ It's painful to read/preprocess data through TF. Use __DataFlow__ to load large datasets (e.g. ImageNet) in __pure Python__ with multi-process prefetch.
+ It's painful to read/preprocess data through TF. Use __DataFlow__ to load large datasets (e.g. ImageNet) in __pure Python__ with autoparallelization.
+ DataFlow has a unified interface, so you can compose and reuse them to perform complex preprocessing.
4. Interface of extensible __Callbacks__.
......
......@@ -8,9 +8,8 @@ from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
import itertools
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
import tqdm
import six
from six.moves import range
from ..utils import logger
......@@ -42,16 +41,29 @@ class InferencerToHook(tf.train.SessionRunHook):
self._inf.on_fetches(run_values.results)
@six.add_metaclass(ABCMeta)
@contextmanager
def _inference_context():
    """
    Context manager to wrap an inference loop.

    If the wrapped body raises ``StopIteration``, ``tf.errors.CancelledError``
    or ``tf.errors.OutOfRangeError``, an error is logged saying that the input
    stopped before reaching its ``size()``, and the same exception is re-raised.
    Any other exception passes through untouched.
    """
    hint = "You might need to check your input implementation."
    try:
        yield
    except (StopIteration,
            tf.errors.CancelledError,
            tf.errors.OutOfRangeError):
        # The input ran dry too early: report it, then let the error propagate.
        logger.error(
            "[InferenceRunner] input stopped before reaching its size()! " + hint)
        raise
class InferenceRunnerBase(Callback):
""" Base methods for inference runner"""
def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None, prefix=None):
""" Base class for inference runner.
Please note that InferenceRunner will use `input.size()` to determine
how many iterations to run, so you want it to be accurate.
"""
def __init__(self, input, infs, extra_hooks=None):
"""
Args:
input (InputSource): the input to use. Must have ``size()``.
infs (list[Inferencer]): list of :class:`Inferencer` to run.
tower_name(str): name scope to build the tower. Must be set
differently if more than one :class:`InferenceRunner` are used.
extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
"""
self._input_source = input
......@@ -66,14 +78,46 @@ class InferenceRunnerBase(Callback):
self._size = input.size()
except NotImplementedError:
raise ValueError("Input used in InferenceRunner must have a size!")
self._tower_name = tower_name
if prefix is not None:
self._tower_name = 'InferenceTower' + prefix
if extra_hooks is None:
extra_hooks = []
self._extra_hooks = extra_hooks
def _before_train(self):
    """Attach the extra hooks, wrap the trainer's session with all hooks,
    and notify the input callbacks that training is about to start."""
    all_hooks = self._hooks
    all_hooks.extend(self._extra_hooks)
    self._hooked_sess = HookedSession(self.trainer.sess, all_hooks)
    self._input_callbacks.before_train()
def _after_train(self):
    """Forward the after-train event to the input callbacks."""
    input_cbs = self._input_callbacks
    input_cbs.after_train()
class InferenceRunner(InferenceRunnerBase):
"""
A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
"""
def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None):
    """
    Args:
        input (InputSource or DataFlow): The :class:`InputSource` to run
            inference on. If given a DataFlow, will use :class:`FeedInput`.
        infs (list): a list of :class:`Inferencer` instances.
        tower_name (str): the name scope of the tower to build. Need to set a
            different one if multiple InferenceRunner are used.
        extra_hooks (list[SessionRunHook]): extra hooks, passed through to the
            base class.
    """
    # A plain DataFlow is wrapped into a finite FeedInput first.
    source = FeedInput(input, infinite=False) if isinstance(input, DataFlow) else input
    assert isinstance(source, InputSource), source
    self._tower_name = tower_name
    super(InferenceRunner, self).__init__(
        source, infs, extra_hooks=extra_hooks)
def _build_hook(self, inf):
    """Build an :class:`InferencerToHook` for ``inf``, using the tensors that
    the tower handle returns for the names ``inf`` asks to fetch."""
    fetch_tensors = self._tower_handle.get_tensors(inf.get_fetches())
    return InferencerToHook(inf, fetch_tensors)
def _setup_graph(self):
# Use predict_tower in train config. either gpuid or -1
tower_id = self.trainer.config.predict_tower[0]
......@@ -95,61 +139,19 @@ class InferenceRunnerBase(Callback):
inf.setup_graph(self.trainer)
self._input_callbacks.setup_graph(self.trainer)
def _before_train(self):
    """Attach the extra hooks, wrap the trainer's session with all hooks,
    and notify the input callbacks that training is about to start."""
    all_hooks = self._hooks
    all_hooks.extend(self._extra_hooks)
    self._hooked_sess = HookedSession(self.trainer.sess, all_hooks)
    self._input_callbacks.before_train()
def _after_train(self):
    """Forward the after-train event to the input callbacks."""
    input_cbs = self._input_callbacks
    input_cbs.after_train()
@abstractmethod
def _build_hook(self, inf):
    """Abstract: implemented by subclasses; takes one inferencer ``inf``."""
    return None
def _trigger(self):
for inf in self.infs:
inf.before_epoch()
# iterate over the data, and run the hooked session
self._input_source.reset_state()
msg = "You might need to check your input implementation."
try:
with _inference_context():
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
self._hooked_sess.run(fetches=[])
except (StopIteration, tf.errors.CancelledError,
tf.errors.OutOfRangeError):
logger.error(
"[InferenceRunner] input stopped before reaching its size()! " + msg)
raise
for inf in self.infs:
inf.trigger_epoch()
class InferenceRunner(InferenceRunnerBase):
    """
    A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
    """

    def __init__(self, input, infs, tower_name='InferenceTower', extra_hooks=None):
        """
        Args:
            input (InputSource or DataFlow): The :class:`InputSource` to run
                inference on. If given a DataFlow, will use :class:`FeedInput`.
            infs (list): a list of :class:`Inferencer` instances.
        """
        # A plain DataFlow is wrapped into a finite FeedInput first.
        source = FeedInput(input, infinite=False) if isinstance(input, DataFlow) else input
        assert isinstance(source, InputSource), source
        super(InferenceRunner, self).__init__(
            source, infs, tower_name=tower_name, extra_hooks=extra_hooks)

    def _build_hook(self, inf):
        """Build an :class:`InferencerToHook` for ``inf``, using the tensors
        that the tower handle returns for the names ``inf`` asks to fetch."""
        fetch_tensors = self._tower_handle.get_tensors(inf.get_fetches())
        return InferencerToHook(inf, fetch_tensors)
@deprecated("Just use InferenceRunner since it now accepts TensorInput!", "2017-11-11")
def FeedfreeInferenceRunner(*args, **kwargs):
    """Deprecated alias: forwards all arguments to :class:`InferenceRunner`."""
    runner = InferenceRunner(*args, **kwargs)
    return runner
......@@ -157,24 +159,26 @@ def FeedfreeInferenceRunner(*args, **kwargs):
class DataParallelInferenceRunner(InferenceRunnerBase):
"""
Inference by feeding datapoints in a data-parallel way to multiple GPUs.
Doesn't support remapped InputSource for now.
Inference with data-parallel support on multiple GPUs.
It will build one predict tower on each GPU, and run prediction
with a larger batch.
"""
def __init__(self, input, infs, gpus):
"""
Args:
input (DataParallelFeedInput or DataFlow)
input (DataFlow or QueueInput)
gpus (list[int]): list of GPU id
"""
self._tower_names = ['InferenceTower{}'.format(k) for k in range(len(gpus))]
if isinstance(input, DataFlow):
input = QueueInput(input)
assert isinstance(input, QueueInput), input
super(DataParallelInferenceRunner, self).__init__(input, infs)
self._gpus = gpus
def _setup_graph(self):
cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
# build each predict tower
self._handles = []
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, t in enumerate(self._gpus):
......@@ -184,11 +188,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self.trainer.predictor_factory.build(
tower_name, device, self._input_source))
# setup callbacksand hooks
# setup callbacks and hooks
self._input_callbacks = Callbacks(cbs)
self._hooks = [self._build_hook(inf) for inf in self.infs]
self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
self._hooks_parallel.extend(self._input_callbacks.get_hooks())
# InputSource might have hooks which break us.
# e.g. hooks from StagingInputWrapper will force the consumption
# of nr_tower datapoints in every run.
input_hooks = self._input_callbacks.get_hooks()
self._hooks = [self._build_hook(inf) for inf in self.infs] + input_hooks
self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs] + input_hooks
for inf in self.infs:
inf.setup_graph(self.trainer)
......@@ -232,15 +240,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self._input_source.reset_state()
total = self._size
nr_tower = len(self._gpus)
with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
while total >= nr_tower:
self._parallel_hooked_sess.run(fetches=[])
pbar.update(nr_tower)
total -= nr_tower
# take care of the rest
while total > 0:
self._hooked_sess.run(fetches=[])
pbar.update(1)
total -= 1
with _inference_context():
with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
while total >= nr_tower:
self._parallel_hooked_sess.run(fetches=[])
pbar.update(nr_tower)
total -= nr_tower
# take care of the rest
for _ in range(total):
self._hooked_sess.run(fetches=[])
pbar.update(1)
for inf in self.infs:
inf.trigger_epoch()
......@@ -143,7 +143,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
nr_proc (int): number of processes to use.
hwm (int): the zmq "high-water mark" for both sender and receiver.
"""
assert os.name != 'nt', "PrefetchDataZMQ doesn't support windows! Consider PrefetchData instead."
assert os.name != 'nt', "PrefetchDataZMQ doesn't support windows! Consider PrefetchData instead."
super(PrefetchDataZMQ, self).__init__(ds)
try:
self._size = ds.size()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment