Commit b9a15df7 authored by Yuxin Wu

Use QueueInput in DataParallelInferenceRunner, correctness verified.

parent 5c241e09
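As a quick illustration of the headline change (not part of the commit itself): a plain DataFlow handed to DataParallelInferenceRunner is now wrapped in a QueueInput instead of a DataParallelFeedInput. The dataflow and inferencer below are placeholders; `ScalarStats('cost')` assumes the model defines a tensor named `cost`.

```python
# Hypothetical usage after this commit; FakeData stands in for a real
# validation DataFlow.
from tensorpack.dataflow import FakeData
from tensorpack.callbacks import DataParallelInferenceRunner, ScalarStats

df = FakeData([[224, 224, 3]], size=100)
# A plain DataFlow is now wrapped in QueueInput internally:
runner = DataParallelInferenceRunner(df, [ScalarStats('cost')], gpus=[0, 1])
```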
......@@ -364,6 +364,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'get_predictors',
'vs_name_for_predictor',
'dump_chkpt_vars',
'VisualQA',
'ParamRestore']:
return True
if name in ['get_data', 'size', 'reset_state']:
......
......@@ -74,7 +74,6 @@ class Model(ModelDesc):
def get_data(name):
isTrain = name == 'train'
augmentors = fbresnet_augmentor(isTrain)
augmentors.append(imgaug.ToUint8())
datadir = args.data
return get_imagenet_dataflow(
datadir, name, BATCH_SIZE, augmentors, dir_structure='original')
......
......@@ -87,6 +87,7 @@ def get_imagenet_dataflow(
http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
"""
assert name in ['train', 'val', 'test']
assert datadir is not None
isTrain = name == 'train'
cpu = min(30, multiprocessing.cpu_count())
if isTrain:
......
......@@ -20,12 +20,11 @@ from ..dataflow.base import DataFlow
from ..graph_builder.input_source_base import InputSource
from ..graph_builder.input_source import (
FeedInput, DataParallelFeedInput)
FeedInput, QueueInput)
from .base import Callback
from .group import Callbacks
from .inference import Inferencer
from .hooks import CallbackToHook
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
'DataParallelInferenceRunner']
......@@ -151,7 +150,7 @@ class InferenceRunner(InferenceRunnerBase):
return InferencerToHook(inf, fetches)
@deprecated("Just use InferenceRunner since it now accepts TensorInput!")
@deprecated("Just use InferenceRunner since it now accepts TensorInput!", "2017-11-11")
def FeedfreeInferenceRunner(*args, **kwargs):
return InferenceRunner(*args, **kwargs)
......@@ -170,9 +169,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
"""
self._tower_names = ['InferenceTower{}'.format(k) for k in range(len(gpus))]
if isinstance(input, DataFlow):
input = DataParallelFeedInput(input, self._tower_names)
assert isinstance(input, DataParallelFeedInput), input
input = QueueInput(input)
super(DataParallelInferenceRunner, self).__init__(input, infs)
self._gpus = gpus
......@@ -187,13 +184,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self.trainer.predictor_factory.build(
tower_name, device, self._input_source))
# setup feeds and hooks
self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
# setup callbacks and hooks
self._input_callbacks = Callbacks(cbs)
self._hooks = [self._build_hook(inf) for inf in self.infs]
self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs])
self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
self._hooks_parallel.extend(self._input_callbacks.get_hooks())
for inf in self.infs:
inf.setup_graph(self.trainer)
self._input_callbacks.setup_graph(self.trainer)
class InferencerToHookDataParallel(InferencerToHook):
def __init__(self, inf, fetches, size):
......@@ -223,7 +222,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
return InferencerToHook(inf, fetches)
def _before_train(self):
self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
super(DataParallelInferenceRunner, self)._before_train()
self._parallel_hooked_sess = HookedSession(self.trainer.sess, self._hooks_parallel)
def _trigger(self):
......@@ -239,16 +238,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
pbar.update(nr_tower)
total -= nr_tower
# take care of the rest
try:
while total > 0:
# TODO XXX doesn't support remap
feed = self._input_source.next_feed(cnt=1)
self._hooked_sess.run(fetches=[], feed_dict=feed)
pbar.update(1)
total -= 1
except AttributeError:
logger.error(
"[DataParallelInferenceRunner] doesn't support InputSource wrappers very well!")
logger.error("[DataParallelInferenceRunner] Skipping the rest of the datapoints ...")
while total > 0:
self._hooked_sess.run(fetches=[])
pbar.update(1)
total -= 1
for inf in self.infs:
inf.trigger_epoch()
......@@ -185,7 +185,7 @@ class HumanHyperParamSetter(HyperParamSetter):
"""
super(HumanHyperParamSetter, self).__init__(param)
self.file_name = os.path.join(logger.LOG_DIR, file_name)
logger.info("Use {} to control hyperparam {}.".format(
logger.info("Use {} to set hyperparam: '{}'.".format(
self.file_name, self.param.readable_name))
def _get_value_to_set(self):
......
......@@ -5,6 +5,7 @@
from ..base import DataFlow
from ...utils.timer import timed_operation
from ...utils import logger
from six.moves import zip, map
from collections import Counter
import json
......@@ -26,6 +27,7 @@ class VisualQA(DataFlow):
"""
def __init__(self, question_file, annotation_file):
logger.warn("dataset.VisualQA is deprecated!")
with timed_operation('Reading VQA JSON file'):
qobj, aobj = list(map(read_json, [question_file, annotation_file]))
self.task_type = qobj['task_type']
......
......@@ -275,7 +275,7 @@ class ThreadedMapData(ProxyDataFlow):
dp = self.queue_get_stoppable(self.inq)
dp = self.func(dp)
if dp is not None:
self.queue_put_stoppable(self.outq, dp)
self.outq.put(dp)
else:
assert not self._strict, \
"[ThreadedMapData] Map function cannot return None when strict mode is used."
......@@ -345,3 +345,8 @@ class ThreadedMapData(ProxyDataFlow):
for _ in range(self.buffer_size):
self._in_queue.put(next(self._iter))
yield self._out_queue.get()
def __del__(self):
for p in self._threads:
p.stop()
p.join()
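The new `__del__` above assumes each worker exposes `stop()` and `join()`. A standalone sketch of that stoppable-thread pattern (this is not tensorpack's own `StoppableThread`, just the general idea):

```python
import threading

class StoppableWorker(threading.Thread):
    """Sketch of a thread that exits cooperatively once stop() is called."""
    def __init__(self, func):
        super(StoppableWorker, self).__init__()
        self.daemon = True
        self._func = func
        self._stop_evt = threading.Event()

    def stop(self):
        self._stop_evt.set()

    def run(self):
        # Re-check the event between units of work so stop() takes effect.
        while not self._stop_evt.is_set():
            self._func()
```

Calling `stop()` and then `join()` from `__del__`, as the hunk above does, lets the owner tear workers down deterministically.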
......@@ -192,13 +192,13 @@ class EnqueueThread(ShareSessionThread):
except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
logger.exception("Exception in {}:".format(self.name))
finally:
try:
self.close_op.run()
except Exception:
pass
logger.info("EnqueueThread Exited.")
logger.info("{} Exited.".format(self.name))
class QueueInput(FeedfreeInput):
......@@ -234,6 +234,10 @@ class QueueInput(FeedfreeInput):
self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
def _create_ema_callback(self):
"""
Create a hook-only callback which maintains an EMA of the queue size.
Also adds a tf.summary.scalar of the EMA.
"""
with self.cached_name_scope():
# in TF there is no API to get queue capacity, so we can only summary the size
size = tf.cast(self.queue.size(), tf.float32, name='queue_size')
......
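A hedged TF1 sketch of what the new docstring describes, maintaining an EMA of the queue size and writing it as a scalar summary. The decay value is made up, and tensorpack's real callback goes through its own hook machinery rather than this bare pattern:

```python
import tensorflow as tf

# Stand-in for self.queue.size(); TF exposes no queue-capacity API,
# so only the current size can be summarized.
size = tf.cast(tf.constant(50), tf.float32, name='queue_size')
ema = tf.train.ExponentialMovingAverage(decay=0.99)
update_op = ema.apply([size])   # run this op periodically to update the EMA
tf.summary.scalar('queue_size_ema', ema.average(size))
```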
......@@ -60,7 +60,8 @@ class PredictorFactory(object):
input (InputSource): must be setup already. If None, will use InputDesc from the model.
"""
logger.info("Building predictor tower '{}' on device {} ...".format(tower_name, device))
assert tower_name not in self._names_built
assert tower_name not in self._names_built, \
"Prediction tower with name '{}' already exists!".format(tower_name)
with tf.device(device), \
TowerContext(tower_name, is_training=False), \
......
......@@ -104,6 +104,9 @@ class SaverRestore(SessionInit):
prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
ignore (list[str]): list of tensor names that should be ignored during loading, e.g. learning-rate
"""
if model_path.endswith('.npy') or model_path.endswith('.npz'):
logger.warn("SaverRestore expect a TF checkpoint, but got a model path '{}'.".format(model_path) +
" To load from a dict, use 'DictRestore'.")
model_path = get_checkpoint_path(model_path)
self.path = model_path
self.prefix = prefix
......@@ -192,6 +195,7 @@ class DictRestore(SessionInit):
Args:
param_dict (dict): a dict of {name: value}
"""
assert isinstance(param_dict, dict), type(param_dict)
# use varname (with :0) for consistency
self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}
......@@ -220,7 +224,7 @@ class DictRestore(SessionInit):
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
@deprecated("Use `DictRestore` instead!", "2017-06-01")
@deprecated("Use `DictRestore` instead!", "2017-09-01")
def ParamRestore(d):
return DictRestore(d)
......
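To make the distinction behind the new warning concrete (file paths below are placeholders): SaverRestore is for TF checkpoints, while a plain dict of arrays, e.g. loaded from an `.npz` file, should go through DictRestore:

```python
import numpy as np
from tensorpack.tfutils.sessinit import SaverRestore, DictRestore

# TF checkpoint -> SaverRestore (path is hypothetical):
init = SaverRestore('train_log/run1/model-10000')

# dict of {name: value} -> DictRestore (file is hypothetical):
params = dict(np.load('converted_model.npz'))
init = DictRestore(params)
```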
......@@ -62,8 +62,8 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
Returns:
List of outputs of ``func``, evaluated on each tower.
"""
logger.info("Training a model of {} tower".format(len(towers)))
if len(towers) > 1:
logger.info("Training a model of {} towers".format(len(towers)))
_check_tf_version()
ret = []
......
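The corrected log message refers to the tower-building loop; a generic sketch of that multi-GPU pattern (a simplification, not MultiGPUTrainerBase's exact code):

```python
import tensorflow as tf

def build_on_towers(towers, func):
    """Evaluate `func` once per GPU, each under its own device/name scope."""
    ret = []
    for idx, gpu_id in enumerate(towers):
        with tf.device('/gpu:{}'.format(gpu_id)), \
                tf.name_scope('tower{}'.format(idx)):
            ret.append(func())
    return ret
```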