Commit 897d29e3 authored by Yuxin Wu's avatar Yuxin Wu

callback.get_tensor_maybe_in_tower

parent 395786db
......@@ -7,6 +7,7 @@ from abc import ABCMeta
import six
from ..utils.develop import log_deprecated
from ..tfutils.common import get_op_or_tensor_by_name
from ..train.tower import TowerTrainer
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
......@@ -205,6 +206,26 @@ class Callback(object):
def __str__(self):
return type(self).__name__
def get_tensors_maybe_in_tower(self, names):
"""
Get tensors in the graph.
Will automatically check for the __first training tower__
if no tensor with the given name exists.
"""
def get_tensor(name):
msg = "Tensor {} not found in the graph!".format(name)
try:
return get_op_or_tensor_by_name(name)
except KeyError:
pass
assert isinstance(self.trainer, TowerTrainer), msg
towers = self.trainer.tower_func.towers
try:
return towers.training()[name]
except KeyError:
raise KeyError(msg)
return [get_tensor(name) for name in names]
class ProxyCallback(Callback):
""" A callback which proxy all methods to another callback.
......
......@@ -12,7 +12,7 @@ from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..tfutils.common import (
get_op_tensor_name, get_op_or_tensor_by_name, get_global_step_var)
get_op_tensor_name, get_global_step_var)
from .base import Callback
__all__ = ['TensorPrinter', 'StepTensorPrinter', 'ProgressBar']
......@@ -33,7 +33,7 @@ class TensorPrinter(Callback):
self._names = names
def _setup_graph(self):
self._fetches = get_op_or_tensor_by_name(self._names)
self._fetches = self.get_tensors_maybe_in_tower(self._names)
def _before_run(self, _):
return self._fetches
......@@ -70,7 +70,7 @@ class ProgressBar(Callback):
self._total = self.trainer.steps_per_epoch
self._tqdm_args = get_tqdm_kwargs(leave=True)
self._fetches = get_op_or_tensor_by_name(self._names) or None
self._fetches = self.get_tensors_maybe_in_tower(self._names) or None
if self._fetches:
self._fetches = tf.train.SessionRunArgs(self._fetches)
self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
......
......@@ -226,6 +226,14 @@ class TowerTensorHandles(object):
return self._handles[name_or_index]
return self._name_to_handle[name_or_index]
def training(self):
"""
Returns:
Still a :class:`TowerTensorHandles`, containing only the training towers.
"""
handles = [h for h in self._handles if h.is_training]
return TowerTensorHandles(handles)
class TowerTensorHandle(object):
"""
......@@ -315,3 +323,7 @@ class TowerTensorHandle(object):
The output returned by the tower function.
"""
return self._output
@property
def is_training(self):
return self._ctx.is_training
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment