Commit 1b8fc8f2 authored by Yuxin Wu's avatar Yuxin Wu

Raise error if two StagingInput are nested (#899)

parent 0dbcbac7
...@@ -29,7 +29,7 @@ __all__ = ['PlaceholderInput', 'FeedInput', 'FeedfreeInput', ...@@ -29,7 +29,7 @@ __all__ = ['PlaceholderInput', 'FeedInput', 'FeedfreeInput',
'QueueInput', 'BatchQueueInput', 'QueueInput', 'BatchQueueInput',
'DummyConstantInput', 'TensorInput', 'DummyConstantInput', 'TensorInput',
'ZMQInput', 'TFDatasetInput', 'ZMQInput', 'TFDatasetInput',
'StagingInputWrapper', 'StagingInput'] 'StagingInput']
def _get_reset_callback(df): def _get_reset_callback(df):
...@@ -527,6 +527,8 @@ class StagingInput(FeedfreeInput): ...@@ -527,6 +527,8 @@ class StagingInput(FeedfreeInput):
This means that in multi-GPU training, you should ensure that each call on `hooked_sess.run` This means that in multi-GPU training, you should ensure that each call on `hooked_sess.run`
depends on either all input tensors on all GPUs, or no input tensors at all. depends on either all input tensors on all GPUs, or no input tensors at all.
As a result you cannot use this InputSource for :class:`InferenceRunner`. As a result you cannot use this InputSource for :class:`InferenceRunner`.
More than one StagingInput cannot be used together.
""" """
class StagingCallback(Callback): class StagingCallback(Callback):
""" """
...@@ -562,6 +564,7 @@ class StagingInput(FeedfreeInput): ...@@ -562,6 +564,7 @@ class StagingInput(FeedfreeInput):
# Only step the stagingarea when the input is evaluated in this sess.run # Only step the stagingarea when the input is evaluated in this sess.run
fetches = ctx.original_args.fetches fetches = ctx.original_args.fetches
if dependency_of_fetches(fetches, self._check_dependency_op): if dependency_of_fetches(fetches, self._check_dependency_op):
# note: this disable nesting of StagingInput
return self.fetches return self.fetches
def __init__(self, input, nr_stage=1, device=None): def __init__(self, input, nr_stage=1, device=None):
...@@ -576,6 +579,9 @@ class StagingInput(FeedfreeInput): ...@@ -576,6 +579,9 @@ class StagingInput(FeedfreeInput):
""" """
if not isinstance(input, FeedfreeInput): if not isinstance(input, FeedfreeInput):
raise ValueError("StagingInput takes a FeedfreeInput! Got {}".format(input)) raise ValueError("StagingInput takes a FeedfreeInput! Got {}".format(input))
if isinstance(input, StagingInput):
raise ValueError("StagingInput cannot be nested!")
self._input = input self._input = input
self._nr_stage = nr_stage self._nr_stage = nr_stage
...@@ -660,8 +666,3 @@ class StagingInput(FeedfreeInput): ...@@ -660,8 +666,3 @@ class StagingInput(FeedfreeInput):
run_before=False, run_before=False,
run_as_trigger=False, run_as_trigger=False,
run_step=True) run_step=True)
@deprecated("Renamed to StagingInput", "2018-08-01")
def StagingInputWrapper(*args, **kwargs):
return StagingInput(*args, **kwargs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment