Commit d89b6f07 authored by Yuxin Wu's avatar Yuxin Wu

use TowerHandle.get_tensor to access variables (fix #1409)

parent 6f0ba596
...@@ -12,6 +12,7 @@ from scipy.signal import convolve2d ...@@ -12,6 +12,7 @@ from scipy.signal import convolve2d
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.gpu import change_gpu
from tensorpack.utils.argtools import shape2d, shape4d from tensorpack.utils.argtools import shape2d, shape4d
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
......
...@@ -385,9 +385,9 @@ class TowerTensorHandle(object): ...@@ -385,9 +385,9 @@ class TowerTensorHandle(object):
def get_tensor(self, name): def get_tensor(self, name):
""" """
Get a tensor in this tower. The name can be: Get a tensor in this tower. The name argument can be:
1. The name of the tensor without any tower prefix. 1. The name of a tensor/variable without any tower prefix.
2. A name in the input signature, if it is used when building the tower. 2. A name in the input signature, if it is used when building the tower.
...@@ -405,7 +405,6 @@ class TowerTensorHandle(object): ...@@ -405,7 +405,6 @@ class TowerTensorHandle(object):
except KeyError: except KeyError:
if name in self._extra_tensor_names: if name in self._extra_tensor_names:
return self._extra_tensor_names[name] return self._extra_tensor_names[name]
raise
else: else:
if name in self._extra_tensor_names: if name in self._extra_tensor_names:
mapped_tensor = self._extra_tensor_names[name] mapped_tensor = self._extra_tensor_names[name]
...@@ -415,6 +414,8 @@ class TowerTensorHandle(object): ...@@ -415,6 +414,8 @@ class TowerTensorHandle(object):
" Assuming it is the input '{}'.".format(mapped_tensor.name)) " Assuming it is the input '{}'.".format(mapped_tensor.name))
return mapped_tensor return mapped_tensor
return ret return ret
# should also allow variables in get_tensor
return self.get_variable(name)
def get_tensors(self, names): def get_tensors(self, names):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment