Commit f7ab74a3 authored by Yuxin Wu

make StagingInput dependency-safe

parent 6e5ed1f1
@@ -117,7 +117,7 @@ def get_config(cifar_classnum):
         return lr * 0.31
     return TrainConfig(
         model=Model(cifar_classnum),
-        dataflow=dataset_train,
+        data=QueueInput(dataset_train),
         callbacks=[
             ModelSaver(),
             InferenceRunner(dataset_test,
@@ -131,7 +131,7 @@ def get_config(cifar_classnum):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', required=True)
+    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
     parser.add_argument('--load', help='load model')
     parser.add_argument('--classnum', help='10 for cifar10 or 100 for cifar100',
                         type=int, default=10)
@@ -147,6 +147,6 @@ if __name__ == '__main__':
         config.session_init = SaverRestore(args.load)

     num_gpu = get_num_gpu()
-    trainer = QueueInputTrainer() if num_gpu <= 1 \
+    trainer = SimpleTrainer() if num_gpu <= 1 \
         else SyncMultiGPUTrainerParameterServer(num_gpu)
     launch_train_with_config(config, trainer)
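Note on the example change above: the CIFAR example now passes its input pipeline explicitly as an InputSource (`data=QueueInput(dataset_train)`) instead of a bare dataflow, chooses the trainer from the detected GPU count, and no longer requires `--gpu`; `QueueInputTrainer` becomes redundant once the queue is set up by the InputSource itself. A minimal sketch of the resulting usage pattern, assuming `Model` and `dataset_train` stand in for the example's own ModelDesc subclass and training dataflow:

```python
from tensorpack import (TrainConfig, QueueInput, ModelSaver, SimpleTrainer,
                        SyncMultiGPUTrainerParameterServer,
                        launch_train_with_config)
from tensorpack.utils.gpu import get_num_gpu

# `Model` and `dataset_train` are placeholders for the example's own
# ModelDesc subclass and training dataflow (not defined here).
config = TrainConfig(
    model=Model(cifar_classnum=10),
    data=QueueInput(dataset_train),   # explicit InputSource instead of a bare dataflow
    callbacks=[ModelSaver()],
)

num_gpu = get_num_gpu()
trainer = SimpleTrainer() if num_gpu <= 1 \
    else SyncMultiGPUTrainerParameterServer(num_gpu)
launch_train_with_config(config, trainer)
```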
@@ -18,9 +18,10 @@ from ..dataflow import DataFlow, MapData, RepeatedData, DataFlowTerminated
 from ..tfutils.summary import add_moving_summary
 from ..tfutils.common import get_op_tensor_name
 from ..tfutils.tower import get_current_tower_context
+from ..tfutils.dependency import dependency_of_fetches
 from ..utils import logger
 from ..utils.concurrency import ShareSessionThread
-from ..utils.develop import log_deprecated
+from ..utils.develop import log_deprecated, deprecated
 from ..callbacks.base import Callback, CallbackFactory
 from ..callbacks.graph import RunOp
@@ -117,7 +118,8 @@ class EnqueueThread(ShareSessionThread):
         self.op = self.queue.enqueue(self.placehdrs)
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
-        self._lock = threading.Lock()
+        self._running = threading.Event()
+        self._running.set()
         # self._size = queue.size()

     def run(self):
@@ -126,8 +128,8 @@ class EnqueueThread(ShareSessionThread):
                 self.reinitialize_dataflow()
                 while True:
                     # pausable loop
-                    self._lock.acquire()
-                    self._lock.release()
+                    if not self._running.is_set():
+                        self._running.wait()

                     dp = next(self._itr)
                     feed = dict(zip(self.placehdrs, dp))
@@ -151,10 +153,10 @@ class EnqueueThread(ShareSessionThread):
         self._itr = self.dataflow.get_data()

     def pause(self):
-        self._lock.acquire()
+        self._running.clear()

     def resume(self):
-        self._lock.release()
+        self._running.set()


 class QueueInput(FeedfreeInput):
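The EnqueueThread change replaces the Lock-based pause with a `threading.Event`: `pause()` clears the flag, the enqueue loop blocks on `wait()` until `resume()` sets it again, and repeated or out-of-order pause/resume calls stay harmless because setting or clearing an Event is idempotent. A self-contained sketch of the same pattern, with the queue and feed plumbing of the real class left out:

```python
import threading
import time

class PausableWorker(threading.Thread):
    """Minimal stand-in for EnqueueThread's pause/resume mechanics."""

    def __init__(self):
        super(PausableWorker, self).__init__(daemon=True)
        self._running = threading.Event()
        self._running.set()              # start in the running state

    def run(self):
        while True:
            if not self._running.is_set():
                self._running.wait()     # block here while paused
            time.sleep(0.1)              # stand-in for "enqueue one datapoint"

    def pause(self):
        self._running.clear()

    def resume(self):
        self._running.set()

w = PausableWorker()
w.start()
w.pause()    # run() parks at wait() before its next iteration
w.resume()   # run() continues
```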
@@ -486,7 +488,7 @@ class StagingInput(FeedfreeInput):
     it requires that all outputs ever produced by this InputSource will be fetched together.
     This means that in multi-GPU training, you should ensure that each call on `hooked_sess.run`
-    depends on all input tensors on all GPUs.
+    depends on either all input tensors on all GPUs, or no input tensors at all.
     As a result you cannot use this InputSource for :class:`InferenceRunner`.
     """

     class StagingCallback(Callback):
@@ -503,6 +505,7 @@ class StagingInput(FeedfreeInput):
             self.stage_op = self._input._get_stage_op()
             unstage_ops = self._input._get_unstage_ops()
             unstage_op = tf.group(unstage_ops, name='unstage_all')
+            self._check_dependency_op = unstage_ops[0]
             self.fetches = tf.train.SessionRunArgs(
                 fetches=[self.stage_op, unstage_op])
@@ -510,8 +513,8 @@ class StagingInput(FeedfreeInput):
             logger.info("Pre-filling StagingArea ...")
             for k in range(self.nr_stage):
                 self.stage_op.run()
-            logger.info("Successfully put {} element{} to StagingArea.".format(
-                self.nr_stage, "s" if self.nr_stage > 1 else ""))
+            logger.info("{} element{} put into StagingArea.".format(
+                self.nr_stage, "s were" if self.nr_stage > 1 else " was"))

         def _before_run(self, ctx):
             # This has to happen once, right before the first iteration.
@@ -519,6 +522,9 @@ class StagingInput(FeedfreeInput):
             if not self._initialized:
                 self._initialized = True
                 self._prefill()
-            return self.fetches
+            # Only step the stagingarea when the input is evaluated in this sess.run
+            fetches = ctx.original_args.fetches
+            if dependency_of_fetches(fetches, self._check_dependency_op):
+                return self.fetches

     def __init__(self, input, towers=None, nr_stage=1, device=None):
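This hunk is the heart of the "dependency-safe" behaviour: the callback remembers one unstaged tensor (`_check_dependency_op`) and, in `_before_run`, attaches the stage/unstage fetches only when the fetches of the current `hooked_sess.run` actually depend on that tensor. Runs that never consume the input (extra summaries, maintenance callbacks, and so on) therefore no longer step the StagingArea or silently consume data. A rough sketch of the same idea written as a plain `tf.train.SessionRunHook`; `dependency_of_fetches` is the tensorpack helper imported above, while the hook class and its argument names are illustrative:

```python
import tensorflow as tf
from tensorpack.tfutils.dependency import dependency_of_fetches

class ConditionalStagingHook(tf.train.SessionRunHook):
    """Advance a StagingArea only in runs that actually read from it."""

    def __init__(self, stage_op, unstage_ops):
        self._stage_op = stage_op
        self._unstage_op = tf.group(*unstage_ops, name='unstage_all')
        # any single unstaged tensor is enough to test for a dependency
        self._check_dependency_op = unstage_ops[0]

    def before_run(self, run_context):
        fetches = run_context.original_args.fetches
        if dependency_of_fetches(fetches, self._check_dependency_op):
            # this run consumes the input: push one element and pop one element
            return tf.train.SessionRunArgs(
                fetches=[self._stage_op, self._unstage_op])
        # otherwise leave the StagingArea untouched
        return None
```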
@@ -568,9 +574,10 @@ class StagingInput(FeedfreeInput):
                 yield

     def _get_input_tensors(self):
+        with self.cached_name_scope(), self._device_ctx():
             inputs = self._input.get_input_tensors()
-        with self._device_ctx():
-            with self.cached_name_scope():
                 # Putting variables to stagingarea will cause trouble
                 dtypes = []
                 for idx in range(len(inputs)):
@@ -589,6 +596,7 @@ class StagingInput(FeedfreeInput):
             outputs = stage.get()
             if isinstance(outputs, tf.Tensor):  # when size=1, TF doesn't return a list
                 outputs = [outputs]
+
             for vin, vout in zip(inputs, outputs):
                 vout.set_shape(vin.get_shape())
             self._unstage_ops.append(outputs)
@@ -617,4 +625,6 @@
             run_step=True)


-StagingInputWrapper = StagingInput
+@deprecated("Renamed to StagingInput", "2018-08-01")
+def StagingInputWrapper(*args, **kwargs):
+    return StagingInput(*args, **kwargs)
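The bare alias `StagingInputWrapper = StagingInput` could not warn anyone, so it becomes a thin wrapper function carrying the `@deprecated` decorator from `tensorpack.utils.develop`, which emits a message and names the planned removal date. For illustration only, a minimal decorator in that spirit; tensorpack's actual implementation differs:

```python
import functools
import warnings

def deprecated(text="", eos=""):
    """Wrap a callable so every call emits a deprecation warning."""
    def wrapper(func):
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            msg = "{} is deprecated. {}".format(func.__name__, text)
            if eos:
                msg += " It will be removed after {}.".format(eos)
            warnings.warn(msg, DeprecationWarning)
            return func(*args, **kwargs)
        return new_func
    return wrapper
```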
@@ -7,6 +7,10 @@ from ..utils.argtools import graph_memoized
 Utils about parsing dependencies in the graph.
 """

+__all__ = [
+    'dependency_of_targets', 'dependency_of_fetches'
+]
+

 @graph_memoized
 def dependency_of_targets(targets, op):
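Finally, `dependency_of_targets` and `dependency_of_fetches` are listed in `__all__` so the input-source code above (and user code) can import them cleanly. Conceptually the check walks the graph backwards from the fetches through both data inputs and control inputs and asks whether the given op or tensor would be executed; the real helpers also accept the richer fetch structures that `Session.run` allows and are memoized per graph via `@graph_memoized`. A hedged sketch of that backward walk, not tensorpack's actual implementation:

```python
import tensorflow as tf

def depends_on(fetches, target):
    """Return True if evaluating `fetches` would also run `target`'s op."""
    target_op = target.op if isinstance(target, tf.Tensor) else target
    if not isinstance(fetches, (list, tuple)):
        fetches = [fetches]
    stack = [f.op if isinstance(f, tf.Tensor) else f for f in fetches]
    seen = set()
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        if op is target_op:
            return True
        stack.extend(t.op for t in op.inputs)     # data dependencies
        stack.extend(op.control_inputs)           # control dependencies
    return False
```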