Commit 2b41edf7 authored by Yuxin Wu's avatar Yuxin Wu

bugfix

parent 897d29e3
...@@ -7,7 +7,6 @@ from abc import ABCMeta ...@@ -7,7 +7,6 @@ from abc import ABCMeta
import six import six
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils.common import get_op_or_tensor_by_name from ..tfutils.common import get_op_or_tensor_by_name
from ..train.tower import TowerTrainer
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory'] __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
...@@ -212,6 +211,8 @@ class Callback(object): ...@@ -212,6 +211,8 @@ class Callback(object):
Will automatically check for the __first training tower__ Will automatically check for the __first training tower__
if no tensor with the given name exists. if no tensor with the given name exists.
""" """
from ..train.tower import TowerTrainer # noqa
def get_tensor(name): def get_tensor(name):
msg = "Tensor {} not found in the graph!".format(name) msg = "Tensor {} not found in the graph!".format(name)
try: try:
...@@ -221,7 +222,7 @@ class Callback(object): ...@@ -221,7 +222,7 @@ class Callback(object):
assert isinstance(self.trainer, TowerTrainer), msg assert isinstance(self.trainer, TowerTrainer), msg
towers = self.trainer.tower_func.towers towers = self.trainer.tower_func.towers
try: try:
return towers.training()[name] return towers.training()[0][name]
except KeyError: except KeyError:
raise KeyError(msg) raise KeyError(msg)
return [get_tensor(name) for name in names] return [get_tensor(name) for name in names]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment