Commit d89b6f07 authored by Yuxin Wu's avatar Yuxin Wu

use TowerHandle.get_tensor to access variables (fix #1409)

parent 6f0ba596
......@@ -12,6 +12,7 @@ from scipy.signal import convolve2d
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.utils import logger
from tensorpack.utils.gpu import change_gpu
from tensorpack.utils.argtools import shape2d, shape4d
from tensorpack.utils.viz import *
......
......@@ -385,9 +385,9 @@ class TowerTensorHandle(object):
def get_tensor(self, name):
"""
Get a tensor in this tower. The name can be:
Get a tensor in this tower. The name argument can be:
1. The name of the tensor without any tower prefix.
1. The name of a tensor/variable without any tower prefix.
2. A name in the input signature, if it is used when building the tower.
......@@ -405,7 +405,6 @@ class TowerTensorHandle(object):
except KeyError:
if name in self._extra_tensor_names:
return self._extra_tensor_names[name]
raise
else:
if name in self._extra_tensor_names:
mapped_tensor = self._extra_tensor_names[name]
......@@ -415,6 +414,8 @@ class TowerTensorHandle(object):
" Assuming it is the input '{}'.".format(mapped_tensor.name))
return mapped_tensor
return ret
# should also allow variables in get_tensor
return self.get_variable(name)
def get_tensors(self, names):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment