Commit 65e1fa46 authored by Yuxin Wu's avatar Yuxin Wu

update docs about tensor access

parent 8915849e
...@@ -87,7 +87,16 @@ to let this method run every k steps or every k epochs. ...@@ -87,7 +87,16 @@ to let this method run every k steps or every k epochs.
### What you can do in the callback ### What you can do in the callback
* Access tensors / ops in either training / inference mode (need to create them in `_setup_graph`). * Access tensors / ops in either training / inference mode (need to create them in `_setup_graph`).
`self.trainer.get_predictor` is a helper function to create a callable under inference mode. * Use TF methods such as `self.graph.get_tensor_by_name`, to access tensors.
If you're using a `TowerTrainer` instance, more tools are available:
* Use `self.trainer.tower_func.towers` to access the
[tower handles](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.tower.TowerTensorHandles),
and therefore the tensors in each tower.
* [self.get_tensors_maybe_in_tower()](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.Callback.get_tensors_maybe_in_tower)
is a helper function to access tensors in the first training tower.
* [self.trainer.get_predictor()](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.TowerTrainer.get_predictor)
is a helper function to create a callable under inference mode.
* Write stuff to the monitor backend, by `self.trainer.monitors.put_xxx`. * Write stuff to the monitor backend, by `self.trainer.monitors.put_xxx`.
The monitors might direct your events to TensorFlow events file, JSON file, stdout, etc. The monitors might direct your events to TensorFlow events file, JSON file, stdout, etc.
You can get history monitor data as well. See the docs for [Monitors](../../modules/callbacks.html#tensorpack.callbacks.Monitors) You can get history monitor data as well. See the docs for [Monitors](../../modules/callbacks.html#tensorpack.callbacks.Monitors)
......
...@@ -207,9 +207,12 @@ class Callback(object): ...@@ -207,9 +207,12 @@ class Callback(object):
def get_tensors_maybe_in_tower(self, names): def get_tensors_maybe_in_tower(self, names):
""" """
Get tensors in the graph. Get tensors in the graph with the given names.
Will automatically check for the __first training tower__ Will automatically check for the *first training tower*
if no tensor with the given name exists. if no existing tensor is found with the name.
Returns:
[tf.Tensor]
""" """
from ..train.tower import TowerTrainer # noqa from ..train.tower import TowerTrainer # noqa
......
...@@ -80,24 +80,6 @@ class TowerContext(object): ...@@ -80,24 +80,6 @@ class TowerContext(object):
def ns_name(self): def ns_name(self):
return self._name return self._name
# TODO another method to filter by ns_name
def filter_vars_by_vs_name(self, varlist):
"""
Filter the list and only keep those under the current variable scope.
If this tower doesn't contain its own variable scope, return the list as-is.
Args:
varlist (list[tf.Variable] or list[tf.Tensor]):
"""
if not self.has_own_variables:
return varlist
if len(self._vs_name) == 0:
# main_training_tower with no name. assume no other towers has
# been built yet, then varlist contains vars only in the first tower.
return varlist
prefix = self._vs_name + '/'
return [v for v in varlist if v.op.name.startswith(prefix)]
def get_collection_in_tower(self, key): def get_collection_in_tower(self, key):
""" """
Get items from this collection that are added in the current tower. Get items from this collection that are added in the current tower.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment