Commit 2b41edf7 authored by Yuxin Wu's avatar Yuxin Wu

bugfix

parent 897d29e3
......@@ -7,7 +7,6 @@ from abc import ABCMeta
import six
from ..utils.develop import log_deprecated
from ..tfutils.common import get_op_or_tensor_by_name
from ..train.tower import TowerTrainer
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
......@@ -212,6 +211,8 @@ class Callback(object):
Will automatically check for the __first training tower__
if no tensor with the given name exists.
"""
from ..train.tower import TowerTrainer # noqa
def get_tensor(name):
msg = "Tensor {} not found in the graph!".format(name)
try:
......@@ -221,7 +222,7 @@ class Callback(object):
assert isinstance(self.trainer, TowerTrainer), msg
towers = self.trainer.tower_func.towers
try:
return towers.training()[name]
return towers.training()[0][name]
except KeyError:
raise KeyError(msg)
return [get_tensor(name) for name in names]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment