Commit 5cccf2b8 authored by Yuxin Wu's avatar Yuxin Wu

compatible with old version

parent dc378b53
...@@ -18,7 +18,7 @@ class ModelSaver(Callback): ...@@ -18,7 +18,7 @@ class ModelSaver(Callback):
Save the model to logger directory. Save the model to logger directory.
""" """
def __init__(self, keep_recent=10, keep_freq=0.5, def __init__(self, keep_recent=10, keep_freq=0.5,
var_collections=tf.GraphKeys.GLOBAL_VARIABLES): var_collections=tf.GraphKeys().VARIABLES):
""" """
:param keep_recent: see `tf.train.Saver` documentation. :param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation. :param keep_freq: see `tf.train.Saver` documentation.
......
...@@ -93,7 +93,7 @@ class ModelFromMetaGraph(ModelDesc): ...@@ -93,7 +93,7 @@ class ModelFromMetaGraph(ModelDesc):
tf.train.import_meta_graph(filename) tf.train.import_meta_graph(filename)
all_coll = tf.get_default_graph().get_all_collection_keys() all_coll = tf.get_default_graph().get_all_collection_keys()
for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES, for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
tf.GraphKeys.GLOBAL_VARIABLES]: tf.GraphKeys().VARIABLES]:
assert k in all_coll, \ assert k in all_coll, \
"Collection {} not found in metagraph!".format(k) "Collection {} not found in metagraph!".format(k)
......
...@@ -113,7 +113,10 @@ class SaverRestore(SessionInit): ...@@ -113,7 +113,10 @@ class SaverRestore(SessionInit):
:param vars_available: varaible names available in the checkpoint, for existence checking :param vars_available: varaible names available in the checkpoint, for existence checking
:returns: a dict of {var_name: [var, var]} to restore :returns: a dict of {var_name: [var, var]} to restore
""" """
vars_to_restore = tf.global_variables() try:
vars_to_restore = tf.global_variables()
except AttributeError:
vars_to_restore = tf.all_variables()
var_dict = defaultdict(list) var_dict = defaultdict(list)
chkpt_vars_used = set() chkpt_vars_used = set()
for v in vars_to_restore: for v in vars_to_restore:
...@@ -150,7 +153,7 @@ class ParamRestore(SessionInit): ...@@ -150,7 +153,7 @@ class ParamRestore(SessionInit):
self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)} self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
def _init(self, sess): def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) variables = tf.get_collection(tf.GraphKeys().VARIABLES) # TODO
variable_names = set([get_savename_from_varname(k.name) for k in variables]) variable_names = set([get_savename_from_varname(k.name) for k in variables])
param_names = set(six.iterkeys(self.prms)) param_names = set(six.iterkeys(self.prms))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment