Commit 5cccf2b8 authored by Yuxin Wu's avatar Yuxin Wu

compatible with old version

parent dc378b53
......@@ -18,7 +18,7 @@ class ModelSaver(Callback):
Save the model to logger directory.
"""
def __init__(self, keep_recent=10, keep_freq=0.5,
var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
var_collections=tf.GraphKeys().VARIABLES):
"""
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
......
......@@ -93,7 +93,7 @@ class ModelFromMetaGraph(ModelDesc):
tf.train.import_meta_graph(filename)
all_coll = tf.get_default_graph().get_all_collection_keys()
for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
tf.GraphKeys.GLOBAL_VARIABLES]:
tf.GraphKeys().VARIABLES]:
assert k in all_coll, \
"Collection {} not found in metagraph!".format(k)
......
......@@ -113,7 +113,10 @@ class SaverRestore(SessionInit):
:param vars_available: varaible names available in the checkpoint, for existence checking
:returns: a dict of {var_name: [var, var]} to restore
"""
vars_to_restore = tf.global_variables()
try:
vars_to_restore = tf.global_variables()
except AttributeError:
vars_to_restore = tf.all_variables()
var_dict = defaultdict(list)
chkpt_vars_used = set()
for v in vars_to_restore:
......@@ -150,7 +153,7 @@ class ParamRestore(SessionInit):
self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
variables = tf.get_collection(tf.GraphKeys().VARIABLES) # TODO
variable_names = set([get_savename_from_varname(k.name) for k in variables])
param_names = set(six.iterkeys(self.prms))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment