Commit ab86361f authored by Yuxin Wu's avatar Yuxin Wu

use "tensor" instead of "var"

parent 64f97425
...@@ -11,7 +11,7 @@ import os ...@@ -11,7 +11,7 @@ import os
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
from ..tfutils import get_op_var_name from ..tfutils import get_op_tensor_name
__all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam', __all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam',
'HyperParamSetter', 'HumanHyperParamSetter', 'HyperParamSetter', 'HumanHyperParamSetter',
...@@ -62,7 +62,7 @@ class GraphVarParam(HyperParam): ...@@ -62,7 +62,7 @@ class GraphVarParam(HyperParam):
""" """
self.name = name self.name = name
self.shape = shape self.shape = shape
self._readable_name, self.var_name = get_op_var_name(name) self._readable_name, self.var_name = get_op_tensor_name(name)
def setup_graph(self): def setup_graph(self):
""" Will setup the assign operator for that variable. """ """ Will setup the assign operator for that variable. """
......
...@@ -17,8 +17,8 @@ def _global_import(name): ...@@ -17,8 +17,8 @@ def _global_import(name):
_TO_IMPORT = set([ _TO_IMPORT = set([
'sessinit',
'common', 'common',
'sessinit',
'gradproc', 'gradproc',
'argscope', 'argscope',
'tower' 'tower'
......
...@@ -12,10 +12,9 @@ from contextlib import contextmanager ...@@ -12,10 +12,9 @@ from contextlib import contextmanager
__all__ = ['get_default_sess_config', __all__ = ['get_default_sess_config',
'get_global_step', 'get_global_step',
'get_global_step_var', 'get_global_step_var',
'get_op_var_name',
'get_op_tensor_name', 'get_op_tensor_name',
'get_vars_by_names',
'get_tensors_by_names', 'get_tensors_by_names',
'get_op_or_tensor_by_name',
'backup_collection', 'backup_collection',
'restore_collection', 'restore_collection',
'clear_collection', 'clear_collection',
...@@ -87,9 +86,6 @@ def get_op_tensor_name(name): ...@@ -87,9 +86,6 @@ def get_op_tensor_name(name):
return name, name + ':0' return name, name + ':0'
get_op_var_name = get_op_tensor_name
def get_tensors_by_names(names): def get_tensors_by_names(names):
""" """
Get a list of tensors in the default graph by a list of names. Get a list of tensors in the default graph by a list of names.
...@@ -100,12 +96,17 @@ def get_tensors_by_names(names): ...@@ -100,12 +96,17 @@ def get_tensors_by_names(names):
ret = [] ret = []
G = tf.get_default_graph() G = tf.get_default_graph()
for n in names: for n in names:
opn, varn = get_op_var_name(n) opn, varn = get_op_tensor_name(n)
ret.append(G.get_tensor_by_name(varn)) ret.append(G.get_tensor_by_name(varn))
return ret return ret
get_vars_by_names = get_tensors_by_names def get_op_or_tensor_by_name(name):
G = tf.get_default_graph()
if len(name) >= 3 and name[-2] == ':':
return G.get_tensor_by_name(name)
else:
return G.get_operation_by_name(name)
def backup_collection(keys): def backup_collection(keys):
......
...@@ -10,7 +10,7 @@ import tensorflow as tf ...@@ -10,7 +10,7 @@ import tensorflow as tf
import six import six
from ..utils import logger, PREDICT_TOWER from ..utils import logger, PREDICT_TOWER
from .common import get_op_var_name from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname, from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path) is_training_name, get_checkpoint_path)
...@@ -149,7 +149,7 @@ class ParamRestore(SessionInit): ...@@ -149,7 +149,7 @@ class ParamRestore(SessionInit):
param_dict (dict): a dict of {name: value} param_dict (dict): a dict of {name: value}
""" """
# use varname (with :0) for consistency # use varname (with :0) for consistency
self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)} self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}
def _init(self, sess): def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys().VARIABLES) # TODO variables = tf.get_collection(tf.GraphKeys().VARIABLES) # TODO
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment