Commit b9a15df7 authored by Yuxin Wu

Use QueueInput in DataParallelInferenceRunner, correctness verified.

parent 5c241e09
@@ -364,6 +364,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
             'get_predictors',
             'vs_name_for_predictor',
             'dump_chkpt_vars',
+            'VisualQA',
             'ParamRestore']:
         return True
     if name in ['get_data', 'size', 'reset_state']:
......
@@ -74,7 +74,6 @@ class Model(ModelDesc):
 def get_data(name):
     isTrain = name == 'train'
     augmentors = fbresnet_augmentor(isTrain)
-    augmentors.append(imgaug.ToUint8())
     datadir = args.data
     return get_imagenet_dataflow(
         datadir, name, BATCH_SIZE, augmentors, dir_structure='original')
......
@@ -87,6 +87,7 @@ def get_imagenet_dataflow(
     http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
     """
     assert name in ['train', 'val', 'test']
+    assert datadir is not None
     isTrain = name == 'train'
     cpu = min(30, multiprocessing.cpu_count())
     if isTrain:
......
@@ -20,12 +20,11 @@ from ..dataflow.base import DataFlow
 from ..graph_builder.input_source_base import InputSource
 from ..graph_builder.input_source import (
-    FeedInput, DataParallelFeedInput)
+    FeedInput, QueueInput)
 from .base import Callback
 from .group import Callbacks
 from .inference import Inferencer
-from .hooks import CallbackToHook

 __all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
            'DataParallelInferenceRunner']
@@ -151,7 +150,7 @@ class InferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, fetches)


-@deprecated("Just use InferenceRunner since it now accepts TensorInput!")
+@deprecated("Just use InferenceRunner since it now accepts TensorInput!", "2017-11-11")
 def FeedfreeInferenceRunner(*args, **kwargs):
     return InferenceRunner(*args, **kwargs)
@@ -170,9 +169,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         """
         self._tower_names = ['InferenceTower{}'.format(k) for k in range(len(gpus))]
         if isinstance(input, DataFlow):
-            input = DataParallelFeedInput(input, self._tower_names)
-        assert isinstance(input, DataParallelFeedInput), input
+            input = QueueInput(input)
         super(DataParallelInferenceRunner, self).__init__(input, infs)
         self._gpus = gpus
@@ -187,13 +184,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             self.trainer.predictor_factory.build(
                 tower_name, device, self._input_source))

-        # setup feeds and hooks
-        self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
+        # setup callbacks and hooks
+        self._input_callbacks = Callbacks(cbs)
         self._hooks = [self._build_hook(inf) for inf in self.infs]
-        self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs])
+        self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
+        self._hooks_parallel.extend(self._input_callbacks.get_hooks())

         for inf in self.infs:
             inf.setup_graph(self.trainer)
+        self._input_callbacks.setup_graph(self.trainer)

     class InferencerToHookDataParallel(InferencerToHook):
         def __init__(self, inf, fetches, size):
@@ -223,7 +222,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             return InferencerToHook(inf, fetches)

     def _before_train(self):
-        self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
+        super(DataParallelInferenceRunner, self)._before_train()
         self._parallel_hooked_sess = HookedSession(self.trainer.sess, self._hooks_parallel)

     def _trigger(self):
@@ -239,16 +238,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
                 pbar.update(nr_tower)
                 total -= nr_tower
             # take care of the rest
-            try:
-                while total > 0:
-                    # TODO XXX doesn't support remap
-                    feed = self._input_source.next_feed(cnt=1)
-                    self._hooked_sess.run(fetches=[], feed_dict=feed)
-                    pbar.update(1)
-                    total -= 1
-            except AttributeError:
-                logger.error(
-                    "[DataParallelInferenceRunner] doesn't support InputSource wrappers very well!")
-                logger.error("[DataParallelInferenceRunner] Skipping the rest of the datapoints ...")
+            while total > 0:
+                self._hooked_sess.run(fetches=[])
+                pbar.update(1)
+                total -= 1
         for inf in self.infs:
             inf.trigger_epoch()
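Because every tower now dequeues from the shared queue, the remainder (fewer datapoints than towers) can be drained through the first tower alone, with no feed_dict and no reliance on InputSource internals, which is what the removed try/except had to guard against. A simplified reconstruction of the resulting `_trigger` loop, not the verbatim source:

```python
# Inside DataParallelInferenceRunner._trigger(), roughly
# (pbar is a tqdm progress bar; tensorpack uses its own helper for this):
nr_tower = len(self._gpus)
with tqdm.tqdm(total=total) as pbar:
    while total >= nr_tower:
        self._parallel_hooked_sess.run(fetches=[])  # each tower dequeues once
        pbar.update(nr_tower)
        total -= nr_tower
    while total > 0:                                # drain the remainder
        self._hooked_sess.run(fetches=[])           # first tower only
        pbar.update(1)
        total -= 1
```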
@@ -185,7 +185,7 @@ class HumanHyperParamSetter(HyperParamSetter):
         """
         super(HumanHyperParamSetter, self).__init__(param)
         self.file_name = os.path.join(logger.LOG_DIR, file_name)
-        logger.info("Use {} to control hyperparam {}.".format(
+        logger.info("Use {} to set hyperparam: '{}'.".format(
             self.file_name, self.param.readable_name))

     def _get_value_to_set(self):
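The reworded log line points at how this callback is used: it re-reads a text file under the log directory on every trigger and applies the value found there. A usage sketch; the default file name `hyper.txt` and the `name:value` line format are recalled from tensorpack's docs for this class and should be treated as assumptions:

```python
from tensorpack.callbacks import HumanHyperParamSetter

callbacks = [HumanHyperParamSetter('learning_rate')]  # watches <logdir>/hyper.txt
# While the job runs, set a new value from the shell:
#   $ echo "learning_rate:1e-4" >> train_log/my_model/hyper.txt
```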
......
@@ -5,6 +5,7 @@
 from ..base import DataFlow
 from ...utils.timer import timed_operation
+from ...utils import logger
 from six.moves import zip, map
 from collections import Counter
 import json
@@ -26,6 +27,7 @@ class VisualQA(DataFlow):
     """
     def __init__(self, question_file, annotation_file):
+        logger.warn("dataset.VisualQA is deprecated!")
         with timed_operation('Reading VQA JSON file'):
             qobj, aobj = list(map(read_json, [question_file, annotation_file]))
             self.task_type = qobj['task_type']
......
@@ -275,7 +275,7 @@ class ThreadedMapData(ProxyDataFlow):
             dp = self.queue_get_stoppable(self.inq)
             dp = self.func(dp)
             if dp is not None:
-                self.queue_put_stoppable(self.outq, dp)
+                self.outq.put(dp)
             else:
                 assert not self._strict, \
                     "[ThreadedMapData] Map function cannot return None when strict mode is used."
@@ -345,3 +345,8 @@ class ThreadedMapData(ProxyDataFlow):
             for _ in range(self.buffer_size):
                 self._in_queue.put(next(self._iter))
             yield self._out_queue.get()
+
+    def __del__(self):
+        for p in self._threads:
+            p.stop()
+            p.join()
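The new `__del__` stops and joins the worker threads when the dataflow is garbage-collected, so threads blocked on the internal queues no longer outlive the object. Usage is unchanged; a minimal sketch where `df0` and `decode` are placeholders:

```python
from tensorpack.dataflow import ThreadedMapData

# 8 threads apply `decode` to datapoints from df0, with a 200-element buffer.
df = ThreadedMapData(df0, nr_thread=8, map_func=decode, buffer_size=200)
# When `df` is collected, __del__ now stops and joins those 8 threads.
```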
@@ -192,13 +192,13 @@ class EnqueueThread(ShareSessionThread):
         except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
             pass
         except Exception:
-            logger.exception("Exception in EnqueueThread:")
+            logger.exception("Exception in {}:".format(self.name))
         finally:
             try:
                 self.close_op.run()
             except Exception:
                 pass
-            logger.info("EnqueueThread Exited.")
+            logger.info("{} Exited.".format(self.name))


 class QueueInput(FeedfreeInput):
@@ -234,6 +234,10 @@ class QueueInput(FeedfreeInput):
         self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)

     def _create_ema_callback(self):
+        """
+        Create a hook-only callback which maintains an EMA of the queue size,
+        and writes the EMA to a tf.summary.scalar.
+        """
         with self.cached_name_scope():
             # in TF there is no API to get queue capacity, so we can only summary the size
             size = tf.cast(self.queue.size(), tf.float32, name='queue_size')
......
@@ -60,7 +60,8 @@ class PredictorFactory(object):
             input (InputSource): must be setup already. If None, will use InputDesc from the model.
         """
         logger.info("Building predictor tower '{}' on device {} ...".format(tower_name, device))
-        assert tower_name not in self._names_built
+        assert tower_name not in self._names_built, \
+            "Prediction tower with name '{}' already exists!".format(tower_name)
         with tf.device(device), \
                 TowerContext(tower_name, is_training=False), \
......
@@ -104,6 +104,9 @@ class SaverRestore(SessionInit):
             prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
             ignore (list[str]): list of tensor names that should be ignored during loading, e.g. learning-rate
         """
+        if model_path.endswith('.npy') or model_path.endswith('.npz'):
+            logger.warn("SaverRestore expects a TF checkpoint, but got a model path '{}'.".format(model_path) +
+                        " To load from a dict, use 'DictRestore'.")
         model_path = get_checkpoint_path(model_path)
         self.path = model_path
         self.prefix = prefix
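The warning points `.npy`/`.npz` users to DictRestore, which takes an in-memory dict instead of a checkpoint path. A hedged sketch of both paths; file names are placeholders, and the `encoding='latin1'` idiom is the usual Python 3 workaround for dicts pickled under Python 2:

```python
import numpy as np
from tensorpack.tfutils.sessinit import SaverRestore, DictRestore

init = SaverRestore('train_log/model/checkpoint')   # a real TF checkpoint
# For a parameter dict stored as .npy, load it yourself and use DictRestore:
init = DictRestore(np.load('params.npy', encoding='latin1').item())
```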
@@ -192,6 +195,7 @@ class DictRestore(SessionInit):
         Args:
             param_dict (dict): a dict of {name: value}
         """
+        assert isinstance(param_dict, dict), type(param_dict)
         # use varname (with :0) for consistency
         self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}
@@ -220,7 +224,7 @@ class DictRestore(SessionInit):
         upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})


-@deprecated("Use `DictRestore` instead!", "2017-06-01")
+@deprecated("Use `DictRestore` instead!", "2017-09-01")
 def ParamRestore(d):
     return DictRestore(d)
......
@@ -62,8 +62,8 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
         Returns:
             List of outputs of ``func``, evaluated on each tower.
         """
-        logger.info("Training a model of {} tower".format(len(towers)))
         if len(towers) > 1:
+            logger.info("Training a model of {} towers".format(len(towers)))
             _check_tf_version()

         ret = []
......