Commit 6e5ed1f1 authored by Yuxin Wu

add dependency parsing

parent 3e652668
@@ -53,6 +53,9 @@ before_script:
   - protoc --version
   - python -c "import cv2; print('OpenCV '+ cv2.__version__)"
   - python -c "import tensorflow as tf; print('TensorFlow '+ tf.__version__)"
+  # Check that these private names can be imported because tensorpack is using them
+  - python -c "from tensorflow.python.client.session import _FetchHandler"
+  - python -c "from tensorflow.python.training.monitored_session import _HookedSession"
 script:
   - flake8 .
...
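For reference, the same check can be run locally. A small sketch (not part of this commit) that fails loudly if the installed TensorFlow no longer exposes the two private names the CI step above imports:

# Sanity-check sketch (illustrative, not part of the commit): verify that the
# private TensorFlow names tensorpack relies on can still be imported from the
# installed TensorFlow version.
try:
    from tensorflow.python.client.session import _FetchHandler               # noqa: F401
    from tensorflow.python.training.monitored_session import _HookedSession  # noqa: F401
except ImportError as e:
    raise ImportError("The installed TensorFlow does not expose a private name "
                      "that tensorpack relies on: {}".format(e))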
@@ -2,8 +2,9 @@
 ## Build the docs:
 
 ### Dependencies:
-1. Python3
-2. `pip install -r requirements.txt`. These requirements are different from tensorpack dependencies.
+1. Python 3
+2. Remove "tensorflow" from `requirements.txt` since you probably prefer to install TensorFlow by yourself.
+3. `pip install -r requirements.txt`. Note that these requirements are different from tensorpack dependencies.
 
 ### Build HTML docs:
 `make html`
...
@@ -480,6 +480,14 @@ class StagingInput(FeedfreeInput):
     """
     A wrapper around a feedfree input,
     to prefetch the input in StagingArea (on GPUs).
+
+    It works by registering hooks to put & get tensors into the StagingArea.
+    If `get_input_tensors` is called multiple times,
+    all outputs ever produced by this InputSource must be fetched together.
+    This means that in multi-GPU training, you should make sure every call to `hooked_sess.run`
+    depends on all the input tensors on all GPUs.
+    As a result, you cannot use this InputSource for :class:`InferenceRunner`.
     """
 
     class StagingCallback(Callback):
         """
@@ -493,7 +501,8 @@ class StagingInput(FeedfreeInput):
         def _setup_graph(self):
             self.stage_op = self._input._get_stage_op()
-            unstage_op = self._input._get_unstage_op()
+            unstage_ops = self._input._get_unstage_ops()
+            unstage_op = tf.group(*unstage_ops, name='unstage_all')
             self.fetches = tf.train.SessionRunArgs(
                 fetches=[self.stage_op, unstage_op])
@@ -506,6 +515,7 @@ class StagingInput(FeedfreeInput):
         def _before_run(self, ctx):
             # This has to happen once, right before the first iteration.
+            # Doing it in `before_train` may not work, because QueueInput also does its work in `before_train`.
             if not self._initialized:
                 self._initialized = True
                 self._prefill()
@@ -589,10 +599,10 @@ class StagingInput(FeedfreeInput):
         with self.cached_name_scope():
             return tf.group(*self._stage_ops)
 
-    def _get_unstage_op(self):
+    def _get_unstage_ops(self):
         with self.cached_name_scope():
             all_outputs = list(chain.from_iterable(self._unstage_ops))
-            return tf.group(*all_outputs)
+            return all_outputs
 
     # for debugging only
     def _create_ema_callback(self):
...
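The refactor above turns `_get_unstage_op` (which returned one op grouping all outputs) into `_get_unstage_ops` (which returns the raw list), so the caller decides how to combine the unstage ops. A minimal TF 1.x sketch of the underlying StagingArea put/get/group pattern; the names (`area`, `stage_op`, `unstage_all`) are illustrative and this is not tensorpack's API:

import tensorflow as tf
from tensorflow.python.ops.data_flow_ops import StagingArea

x = tf.random_normal(shape=[128, 32])
y = tf.random_normal(shape=[128])

# A staging buffer holding one (features, labels) pair per element.
area = StagingArea(dtypes=[tf.float32, tf.float32], shapes=[[128, 32], [128]])
stage_op = area.put([x, y])   # enqueue the next element
unstage_ops = area.get()      # list of dequeued tensors, one per dtype

# Group the raw list into a single op when every output must be fetched
# together, mirroring `tf.group(*unstage_ops, name='unstage_all')` above.
unstage_all = tf.group(*unstage_ops, name='unstage_all')

with tf.Session() as sess:
    sess.run(stage_op)                  # prefill once before the first step
    sess.run([stage_op, unstage_all])   # steady state: one put + one get per step

The new dependency-parsing module below then lets hooks reason about which of these ops a given `run` call actually depends on.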
import tensorflow as tf
from tensorflow.contrib.graph_editor import get_backward_walk_ops

from ..utils.argtools import graph_memoized

"""
Utils about parsing dependencies in the graph.
"""


@graph_memoized
def dependency_of_targets(targets, op):
    """
    Check whether op is in the subgraph induced by the dependencies of targets.
    The result is memoized.

    This is useful if some SessionRunHooks should be run only together with certain ops.

    Args:
        targets: a tuple of ops or tensors. The targets to find dependencies of.
        op (tf.Operation or tf.Tensor):

    Returns:
        bool
    """
    # TODO tensorarray? sparsetensor?
    if isinstance(op, tf.Tensor):
        op = op.op
    assert isinstance(op, tf.Operation), op

    # alternative implementation can use graph_util.extract_sub_graph
    dependent_ops = get_backward_walk_ops(targets, control_inputs=True)
    return op in dependent_ops


def dependency_of_fetches(fetches, op):
    """
    Check whether op is in the subgraph induced by the dependencies of fetches.
    fetches may have a more general structure than targets.

    Args:
        fetches: An argument to `sess.run`. Nested structure will affect performance.
        op (tf.Operation or tf.Tensor):

    Returns:
        bool
    """
    try:
        from tensorflow.python.client.session import _FetchHandler as FetchHandler
        handler = FetchHandler(tf.get_default_graph(), fetches, {})
        targets = tuple(handler.fetches() + handler.targets())
    except ImportError:
        if isinstance(fetches, list):
            targets = tuple(fetches)
        elif isinstance(fetches, dict):
            raise ValueError("Don't know how to parse dictionary to fetch list! "
                             "This is a bug of tensorpack.")
        else:
            targets = (fetches, )
    return dependency_of_targets(targets, op)


if __name__ == '__main__':
    a = tf.random_normal(shape=[3, 3])
    b = tf.random_normal(shape=[3, 3])
    print(dependency_of_fetches(a, a))
    print(dependency_of_fetches([a, b], a))
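A hedged sketch (not code from this commit) of the use case mentioned in the docstring of `dependency_of_targets`: a SessionRunHook that attaches extra fetches only when the current `run` call actually depends on a given op. `MaybeFetchHook`, `trigger_op`, `extra_fetches`, and the import path of the helpers above are illustrative assumptions:

import tensorflow as tf

from dependency import dependency_of_fetches  # wherever the module above lives


class MaybeFetchHook(tf.train.SessionRunHook):
    def __init__(self, trigger_op, extra_fetches):
        self._trigger_op = trigger_op        # only act when this op is being run
        self._extra_fetches = extra_fetches  # what to fetch alongside it

    def before_run(self, run_context):
        fetches = run_context.original_args.fetches
        if dependency_of_fetches(fetches, self._trigger_op):
            return tf.train.SessionRunArgs(fetches=self._extra_fetches)
        return None  # this run call does not touch the trigger op; do nothing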