Commit 897d29e3 authored by Yuxin Wu's avatar Yuxin Wu

callback.get_tensor_maybe_in_tower

parent 395786db
...@@ -7,6 +7,7 @@ from abc import ABCMeta ...@@ -7,6 +7,7 @@ from abc import ABCMeta
import six import six
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils.common import get_op_or_tensor_by_name from ..tfutils.common import get_op_or_tensor_by_name
from ..train.tower import TowerTrainer
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory'] __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
...@@ -205,6 +206,26 @@ class Callback(object): ...@@ -205,6 +206,26 @@ class Callback(object):
def __str__(self): def __str__(self):
return type(self).__name__ return type(self).__name__
def get_tensors_maybe_in_tower(self, names):
"""
Get tensors in the graph.
Will automatically check for the __first training tower__
if no tensor with the given name exists.
"""
def get_tensor(name):
msg = "Tensor {} not found in the graph!".format(name)
try:
return get_op_or_tensor_by_name(name)
except KeyError:
pass
assert isinstance(self.trainer, TowerTrainer), msg
towers = self.trainer.tower_func.towers
try:
return towers.training()[name]
except KeyError:
raise KeyError(msg)
return [get_tensor(name) for name in names]
class ProxyCallback(Callback): class ProxyCallback(Callback):
""" A callback which proxy all methods to another callback. """ A callback which proxy all methods to another callback.
......
...@@ -12,7 +12,7 @@ from ..utils import logger ...@@ -12,7 +12,7 @@ from ..utils import logger
from ..utils.utils import get_tqdm_kwargs from ..utils.utils import get_tqdm_kwargs
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..tfutils.common import ( from ..tfutils.common import (
get_op_tensor_name, get_op_or_tensor_by_name, get_global_step_var) get_op_tensor_name, get_global_step_var)
from .base import Callback from .base import Callback
__all__ = ['TensorPrinter', 'StepTensorPrinter', 'ProgressBar'] __all__ = ['TensorPrinter', 'StepTensorPrinter', 'ProgressBar']
...@@ -33,7 +33,7 @@ class TensorPrinter(Callback): ...@@ -33,7 +33,7 @@ class TensorPrinter(Callback):
self._names = names self._names = names
def _setup_graph(self): def _setup_graph(self):
self._fetches = get_op_or_tensor_by_name(self._names) self._fetches = self.get_tensors_maybe_in_tower(self._names)
def _before_run(self, _): def _before_run(self, _):
return self._fetches return self._fetches
...@@ -70,7 +70,7 @@ class ProgressBar(Callback): ...@@ -70,7 +70,7 @@ class ProgressBar(Callback):
self._total = self.trainer.steps_per_epoch self._total = self.trainer.steps_per_epoch
self._tqdm_args = get_tqdm_kwargs(leave=True) self._tqdm_args = get_tqdm_kwargs(leave=True)
self._fetches = get_op_or_tensor_by_name(self._names) or None self._fetches = self.get_tensors_maybe_in_tower(self._names) or None
if self._fetches: if self._fetches:
self._fetches = tf.train.SessionRunArgs(self._fetches) self._fetches = tf.train.SessionRunArgs(self._fetches)
self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} " self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
......
...@@ -226,6 +226,14 @@ class TowerTensorHandles(object): ...@@ -226,6 +226,14 @@ class TowerTensorHandles(object):
return self._handles[name_or_index] return self._handles[name_or_index]
return self._name_to_handle[name_or_index] return self._name_to_handle[name_or_index]
def training(self):
"""
Returns:
Still a :class:`TowerTensorHandles`, containing only the training towers.
"""
handles = [h for h in self._handles if h.is_training]
return TowerTensorHandles(handles)
class TowerTensorHandle(object): class TowerTensorHandle(object):
""" """
...@@ -315,3 +323,7 @@ class TowerTensorHandle(object): ...@@ -315,3 +323,7 @@ class TowerTensorHandle(object):
The output returned by the tower function. The output returned by the tower function.
""" """
return self._output return self._output
@property
def is_training(self):
return self._ctx.is_training
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment