Commit f7ab74a3 authored by Yuxin Wu's avatar Yuxin Wu

make StagingInput dependency-safe

parent 6e5ed1f1
......@@ -117,7 +117,7 @@ def get_config(cifar_classnum):
return lr * 0.31
return TrainConfig(
model=Model(cifar_classnum),
dataflow=dataset_train,
data=QueueInput(dataset_train),
callbacks=[
ModelSaver(),
InferenceRunner(dataset_test,
......@@ -131,7 +131,7 @@ def get_config(cifar_classnum):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', required=True)
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('--classnum', help='10 for cifar10 or 100 for cifar100',
type=int, default=10)
......@@ -147,6 +147,6 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load)
num_gpu = get_num_gpu()
trainer = QueueInputTrainer() if num_gpu <= 1 \
trainer = SimpleTrainer() if num_gpu <= 1 \
else SyncMultiGPUTrainerParameterServer(num_gpu)
launch_train_with_config(config, trainer)
......@@ -18,9 +18,10 @@ from ..dataflow import DataFlow, MapData, RepeatedData, DataFlowTerminated
from ..tfutils.summary import add_moving_summary
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..tfutils.dependency import dependency_of_fetches
from ..utils import logger
from ..utils.concurrency import ShareSessionThread
from ..utils.develop import log_deprecated
from ..utils.develop import log_deprecated, deprecated
from ..callbacks.base import Callback, CallbackFactory
from ..callbacks.graph import RunOp
......@@ -117,7 +118,8 @@ class EnqueueThread(ShareSessionThread):
self.op = self.queue.enqueue(self.placehdrs)
self.close_op = self.queue.close(cancel_pending_enqueues=True)
self._lock = threading.Lock()
self._running = threading.Event()
self._running.set()
# self._size = queue.size()
def run(self):
......@@ -126,8 +128,8 @@ class EnqueueThread(ShareSessionThread):
self.reinitialize_dataflow()
while True:
# pausable loop
self._lock.acquire()
self._lock.release()
if not self._running.is_set():
self._running.wait()
dp = next(self._itr)
feed = dict(zip(self.placehdrs, dp))
......@@ -151,10 +153,10 @@ class EnqueueThread(ShareSessionThread):
self._itr = self.dataflow.get_data()
def pause(self):
self._lock.acquire()
self._running.clear()
def resume(self):
self._lock.release()
self._running.set()
class QueueInput(FeedfreeInput):
......@@ -486,7 +488,7 @@ class StagingInput(FeedfreeInput):
it requires that all outputs ever produced by this InputSource will be fetched together.
This means that in multi-GPU training, you should ensure that each call on `hooked_sess.run`
depends on all input tensors on all GPUs.
depends on either all input tensors on all GPUs, or no input tensors at all.
As a result you cannot use this InputSource for :class:`InferenceRunner`.
"""
class StagingCallback(Callback):
......@@ -503,6 +505,7 @@ class StagingInput(FeedfreeInput):
self.stage_op = self._input._get_stage_op()
unstage_ops = self._input._get_unstage_ops()
unstage_op = tf.group(unstage_ops, name='unstage_all')
self._check_dependency_op = unstage_ops[0]
self.fetches = tf.train.SessionRunArgs(
fetches=[self.stage_op, unstage_op])
......@@ -510,8 +513,8 @@ class StagingInput(FeedfreeInput):
logger.info("Pre-filling StagingArea ...")
for k in range(self.nr_stage):
self.stage_op.run()
logger.info("Successfully put {} element{} to StagingArea.".format(
self.nr_stage, "s" if self.nr_stage > 1 else ""))
logger.info("{} element{} put into StagingArea.".format(
self.nr_stage, "s were" if self.nr_stage > 1 else " was"))
def _before_run(self, ctx):
# This has to happen once, right before the first iteration.
......@@ -519,6 +522,9 @@ class StagingInput(FeedfreeInput):
if not self._initialized:
self._initialized = True
self._prefill()
# Only step the stagingarea when the input is evaluated in this sess.run
fetches = ctx.original_args.fetches
if dependency_of_fetches(fetches, self._check_dependency_op):
return self.fetches
def __init__(self, input, towers=None, nr_stage=1, device=None):
......@@ -568,9 +574,10 @@ class StagingInput(FeedfreeInput):
yield
def _get_input_tensors(self):
with self.cached_name_scope(), self._device_ctx():
inputs = self._input.get_input_tensors()
with self._device_ctx():
with self.cached_name_scope():
# Putting variables to stagingarea will cause trouble
dtypes = []
for idx in range(len(inputs)):
......@@ -589,6 +596,7 @@ class StagingInput(FeedfreeInput):
outputs = stage.get()
if isinstance(outputs, tf.Tensor): # when size=1, TF doesn't return a list
outputs = [outputs]
for vin, vout in zip(inputs, outputs):
vout.set_shape(vin.get_shape())
self._unstage_ops.append(outputs)
......@@ -617,4 +625,6 @@ class StagingInput(FeedfreeInput):
run_step=True)
StagingInputWrapper = StagingInput
@deprecated("Renamed to StagingInput", "2018-08-01")
def StagingInputWrapper(*args, **kwargs):
return StagingInput(*args, **kwargs)
......@@ -7,6 +7,10 @@ from ..utils.argtools import graph_memoized
Utils about parsing dependencies in the graph.
"""
__all__ = [
'dependency_of_targets', 'dependency_of_fetches'
]
@graph_memoized
def dependency_of_targets(targets, op):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment