Commit ab86361f authored by Yuxin Wu's avatar Yuxin Wu

use "tensor" instead of "var"

parent 64f97425
......@@ -11,7 +11,7 @@ import os
from .base import Callback
from ..utils import logger
from ..tfutils import get_op_var_name
from ..tfutils import get_op_tensor_name
__all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam',
'HyperParamSetter', 'HumanHyperParamSetter',
......@@ -62,7 +62,7 @@ class GraphVarParam(HyperParam):
"""
self.name = name
self.shape = shape
self._readable_name, self.var_name = get_op_var_name(name)
self._readable_name, self.var_name = get_op_tensor_name(name)
def setup_graph(self):
""" Will setup the assign operator for that variable. """
......
......@@ -17,8 +17,8 @@ def _global_import(name):
_TO_IMPORT = set([
'sessinit',
'common',
'sessinit',
'gradproc',
'argscope',
'tower'
......
......@@ -12,10 +12,9 @@ from contextlib import contextmanager
__all__ = ['get_default_sess_config',
'get_global_step',
'get_global_step_var',
'get_op_var_name',
'get_op_tensor_name',
'get_vars_by_names',
'get_tensors_by_names',
'get_op_or_tensor_by_name',
'backup_collection',
'restore_collection',
'clear_collection',
......@@ -87,9 +86,6 @@ def get_op_tensor_name(name):
return name, name + ':0'
get_op_var_name = get_op_tensor_name
def get_tensors_by_names(names):
"""
Get a list of tensors in the default graph by a list of names.
......@@ -100,12 +96,17 @@ def get_tensors_by_names(names):
ret = []
G = tf.get_default_graph()
for n in names:
opn, varn = get_op_var_name(n)
opn, varn = get_op_tensor_name(n)
ret.append(G.get_tensor_by_name(varn))
return ret
get_vars_by_names = get_tensors_by_names
def get_op_or_tensor_by_name(name):
G = tf.get_default_graph()
if len(name) >= 3 and name[-2] == ':':
return G.get_tensor_by_name(name)
else:
return G.get_operation_by_name(name)
def backup_collection(keys):
......
......@@ -10,7 +10,7 @@ import tensorflow as tf
import six
from ..utils import logger, PREDICT_TOWER
from .common import get_op_var_name
from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path)
......@@ -149,7 +149,7 @@ class ParamRestore(SessionInit):
param_dict (dict): a dict of {name: value}
"""
# use varname (with :0) for consistency
self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}
def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys().VARIABLES) # TODO
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment