Commit 1b8fc8f2 authored by Yuxin Wu's avatar Yuxin Wu

Raise error if two StagingInput are nested (#899)

parent 0dbcbac7
......@@ -29,7 +29,7 @@ __all__ = ['PlaceholderInput', 'FeedInput', 'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
'DummyConstantInput', 'TensorInput',
'ZMQInput', 'TFDatasetInput',
'StagingInputWrapper', 'StagingInput']
'StagingInput']
def _get_reset_callback(df):
......@@ -527,6 +527,8 @@ class StagingInput(FeedfreeInput):
This means that in multi-GPU training, you should ensure that each call on `hooked_sess.run`
depends on either all input tensors on all GPUs, or no input tensors at all.
As a result you cannot use this InputSource for :class:`InferenceRunner`.
More than one StagingInput cannot be used together.
"""
class StagingCallback(Callback):
"""
......@@ -562,6 +564,7 @@ class StagingInput(FeedfreeInput):
# Only step the stagingarea when the input is evaluated in this sess.run
fetches = ctx.original_args.fetches
if dependency_of_fetches(fetches, self._check_dependency_op):
# note: this disable nesting of StagingInput
return self.fetches
def __init__(self, input, nr_stage=1, device=None):
......@@ -576,6 +579,9 @@ class StagingInput(FeedfreeInput):
"""
if not isinstance(input, FeedfreeInput):
raise ValueError("StagingInput takes a FeedfreeInput! Got {}".format(input))
if isinstance(input, StagingInput):
raise ValueError("StagingInput cannot be nested!")
self._input = input
self._nr_stage = nr_stage
......@@ -660,8 +666,3 @@ class StagingInput(FeedfreeInput):
run_before=False,
run_as_trigger=False,
run_step=True)
@deprecated("Renamed to StagingInput", "2018-08-01")
def StagingInputWrapper(*args, **kwargs):
return StagingInput(*args, **kwargs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment