Commit 21c7f94a authored by Yuxin Wu

InferenceRunner now works with QueueInput (#139)

parent ac2031df
@@ -44,8 +44,8 @@ class CallbackTimeLogger(object):
 class Callbacks(Callback):
     """
-    A container to hold all callbacks, and execute them in the right order
-    (e.g. :class:`StatPrinter` will be executed at last).
+    A container to hold all callbacks, and trigger them iteratively.
+    Note that it does nothing to before_run/after_run.
     """
     def __init__(self, cbs):
...
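For context on the docstring change above: the container forwards its other lifecycle methods to every child in order, while before_run/after_run are left to the per-callback hooks. A minimal, self-contained sketch of that forwarding pattern (hypothetical names, not tensorpack's actual classes):

class _DemoCallback(object):
    def setup_graph(self, trainer):
        print("setup_graph on callback", id(self))

    def trigger_epoch(self):
        print("trigger_epoch on callback", id(self))


class _DemoCallbackGroup(object):
    # Forwards lifecycle calls to each child in order; deliberately does
    # nothing for before_run/after_run, mirroring the docstring above.
    def __init__(self, cbs):
        self.cbs = list(cbs)

    def setup_graph(self, trainer):
        for cb in self.cbs:
            cb.setup_graph(trainer)

    def trigger_epoch(self):
        for cb in self.cbs:
            cb.trigger_epoch()


group = _DemoCallbackGroup([_DemoCallback(), _DemoCallback()])
group.setup_graph(trainer=None)
group.trigger_epoch()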
@@ -20,9 +20,10 @@ from ..dataflow.base import DataFlow, DataFlowTerminated
 from ..graph_builder.input_source_base import InputSource
 from ..graph_builder.input_source import (
-    FeedInput, DataParallelFeedInput, FeedfreeInput, TensorInput)
+    FeedInput, DataParallelFeedInput)
 from .base import Callback
+from .group import Callbacks
 from .inference import Inferencer
 from .hooks import CallbackToHook
@@ -79,20 +80,29 @@ class InferenceRunnerBase(Callback):
         tower_id = self.trainer.config.predict_tower[0]
         device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'

-        self._input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
+        input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
         with tf.variable_scope(tf.get_variable_scope(), reuse=True):
             self._tower_handle = self.trainer.predictor_factory.build(
                 self._tower_name, device, self._input_source)

         self._hooks = [self._build_hook(inf) for inf in self.infs]
-        self._hooks.extend([CallbackToHook(cb) for cb in self._input_callbacks])
+        # trigger_{step,epoch}, {before,after}_epoch is ignored.
+        # We assume that InputSource callbacks won't use these methods
+        self._input_callbacks = Callbacks(input_callbacks)
+        self._hooks.extend(self._input_callbacks.get_hooks())

         for inf in self.infs:
             inf.setup_graph(self.trainer)
+        self._input_callbacks.setup_graph(self.trainer)

     def _before_train(self):
         self._hooks.extend(self._extra_hooks)
         self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
+        self._input_callbacks.before_train()
+
+    def _after_train(self):
+        self._input_callbacks.after_train()

     @abstractmethod
     def _build_hook(self, inf):
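Condensing the wiring above (a hedged restatement, not additional code from the repository): the InputSource's own callbacks are grouped into one Callbacks object; the group contributes the per-run hooks that HookedSession executes, while setup_graph/before_train/after_train are forwarded to it explicitly, since trigger_step/trigger_epoch and before_epoch/after_epoch are intentionally ignored for these callbacks. A self-contained sketch of that lifecycle order, using a hypothetical stand-in class:

class _InputSourceCallback(object):
    def setup_graph(self, trainer):
        print("setup_graph: build enqueue ops, placeholders, etc.")

    def before_train(self):
        print("before_train: e.g. start prefetch threads (assumption)")

    def before_run(self):
        print("before_run: fires inside every hooked_sess.run()")

    def after_train(self):
        print("after_train: e.g. stop threads / close queues (assumption)")


cb = _InputSourceCallback()
cb.setup_graph(trainer=None)   # from InferenceRunnerBase._setup_graph, via the group
cb.before_train()              # from _before_train, via the group
for _ in range(2):
    cb.before_run()            # via the hooks returned by get_hooks()
cb.after_train()               # from the new _after_train, via the group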
@@ -108,9 +118,11 @@ class InferenceRunnerBase(Callback):
         try:
             for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
                 self._hooked_sess.run(fetches=[])
-        except (StopIteration, DataFlowTerminated):
-            logger.exception(
+        except (StopIteration, DataFlowTerminated,
+                tf.errors.CancelledError, tf.errors.OutOfRangeError):
+            logger.error(
                 "[InferenceRunner] input stopped before reaching its size()! " + msg)
+            raise
         for inf in self.infs:
             inf.trigger_epoch()
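The widened except clause is what lets this loop work with feed-free inputs: when a TF queue is closed and drained, dequeueing raises tf.errors.OutOfRangeError (or CancelledError if the run is cancelled) rather than a Python-side StopIteration. A small standalone illustration, assuming TensorFlow 1.x:

import tensorflow as tf

q = tf.FIFOQueue(10, [tf.int32])
enqueue = q.enqueue_many([[1, 2, 3]])
close = q.close()
dequeue = q.dequeue()

with tf.Session() as sess:
    sess.run(enqueue)
    sess.run(close)                          # no more elements will ever arrive
    try:
        while True:
            print(sess.run(dequeue))         # prints 1, 2, 3 ...
    except tf.errors.OutOfRangeError:
        print("queue closed and exhausted")  # ... then control lands here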
@@ -130,8 +142,6 @@ class InferenceRunner(InferenceRunnerBase):
         if isinstance(input, DataFlow):
             input = FeedInput(input, infinite=False)
         assert isinstance(input, InputSource), input
-        if isinstance(input, FeedfreeInput):    # TODO support other input
-            assert isinstance(input, TensorInput), "InferenceRunner only accepts TensorInput or FeedInput!"
         super(InferenceRunner, self).__init__(
             input, infs, tower_name=tower_name, extra_hooks=extra_hooks)
...
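Removing the assertion above is what the commit title refers to: any InputSource, including QueueInput, can now be passed to InferenceRunner. A hedged usage sketch; InferenceRunner and QueueInput appear in this diff, while TrainConfig, ScalarStats, ClassificationError, Model, dataset_train and dataset_val are placeholders from the wider tensorpack API and user code, shown only for illustration:

from tensorpack import TrainConfig, InferenceRunner, QueueInput
from tensorpack import ScalarStats, ClassificationError

config = TrainConfig(
    model=Model(),                  # user-defined ModelDesc (placeholder)
    dataflow=dataset_train,         # training DataFlow (placeholder)
    callbacks=[
        # Before this commit only FeedInput/TensorInput were accepted here;
        # a feed-free QueueInput now works as the validation input too.
        InferenceRunner(QueueInput(dataset_val),
                        [ScalarStats('cost'), ClassificationError()]),
    ],
)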
@@ -169,7 +169,7 @@ class FeedfreeInput(InputSource):
 class EnqueueThread(ShareSessionThread):
     def __init__(self, queue, ds, placehdrs):
         super(EnqueueThread, self).__init__()
-        self.name = 'EnqueueThread'
+        self.name = 'EnqueueThread ' + queue.name
         self.daemon = True
         self.dataflow = ds
@@ -222,7 +222,6 @@ class QueueInput(FeedfreeInput):
         return self.ds.size()

     def _setup(self, inputs):
-        logger.info("Setting up the queue for CPU prefetching ...")
         self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         assert len(self._input_placehdrs) > 0, \
             "QueueInput has to be used with some inputs!"
@@ -231,6 +230,7 @@ class QueueInput(FeedfreeInput):
             self.queue = tf.FIFOQueue(
                 50, [x.dtype for x in self._input_placehdrs],
                 name='input_queue')
+        logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))
         self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)

     def _create_ema_callback(self):
...
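The two changes above (the queue name appearing in the EnqueueThread name and in the log line) pay off once more than one input queue exists at a time, e.g. the training input plus the InferenceRunner's input: the scope-qualified queue names tell them apart. A small illustration, assuming TensorFlow 1.x; the scope names below are made up, not the ones tensorpack uses:

import tensorflow as tf

with tf.name_scope('train'):
    q_train = tf.FIFOQueue(50, [tf.float32], name='input_queue')
with tf.name_scope('validation'):
    q_val = tf.FIFOQueue(50, [tf.float32], name='input_queue')

# Distinct, scope-qualified names make log lines and thread names such as
# "EnqueueThread train/input_queue" unambiguous.
print(q_train.name)   # e.g. train/input_queue
print(q_val.name)     # e.g. validation/input_queue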