Commit 65e1fa46 authored by Yuxin Wu's avatar Yuxin Wu

update docs about tensor access

parent 8915849e
......@@ -87,7 +87,16 @@ to let this method run every k steps or every k epochs.
### What you can do in the callback
* Access tensors / ops in either training / inference mode (need to create them in `_setup_graph`).
`self.trainer.get_predictor` is a helper function to create a callable under inference mode.
* Use TF methods such as `self.graph.get_tensor_by_name`, to access tensors.
If you're using a `TowerTrainer` instance, more tools are available:
* Use `self.trainer.tower_func.towers` to access the
[tower handles](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.tower.TowerTensorHandles),
and therefore the tensors in each tower.
* [self.get_tensors_maybe_in_tower()](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.Callback.get_tensors_maybe_in_tower)
is a helper function to access tensors in the first training tower.
* [self.trainer.get_predictor()](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.TowerTrainer.get_predictor)
is a helper function to create a callable under inference mode.
* Write stuff to the monitor backend, by `self.trainer.monitors.put_xxx`.
The monitors might direct your events to TensorFlow events file, JSON file, stdout, etc.
You can get history monitor data as well. See the docs for [Monitors](../../modules/callbacks.html#tensorpack.callbacks.Monitors)
......
......@@ -207,9 +207,12 @@ class Callback(object):
def get_tensors_maybe_in_tower(self, names):
"""
Get tensors in the graph.
Will automatically check for the __first training tower__
if no tensor with the given name exists.
Get tensors in the graph with the given names.
Will automatically check for the *first training tower*
if no existing tensor is found with the name.
Returns:
[tf.Tensor]
"""
from ..train.tower import TowerTrainer # noqa
......
......@@ -80,24 +80,6 @@ class TowerContext(object):
def ns_name(self):
return self._name
# TODO another method to filter by ns_name
def filter_vars_by_vs_name(self, varlist):
"""
Filter the list and only keep those under the current variable scope.
If this tower doesn't contain its own variable scope, return the list as-is.
Args:
varlist (list[tf.Variable] or list[tf.Tensor]):
"""
if not self.has_own_variables:
return varlist
if len(self._vs_name) == 0:
# main_training_tower with no name. assume no other towers has
# been built yet, then varlist contains vars only in the first tower.
return varlist
prefix = self._vs_name + '/'
return [v for v in varlist if v.op.name.startswith(prefix)]
def get_collection_in_tower(self, key):
"""
Get items from this collection that are added in the current tower.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment