Commit 4e7977f5 authored by Yuxin Wu's avatar Yuxin Wu

remove all_variables

parent f1fdb42e
......@@ -43,12 +43,7 @@ class GraphVarParam(HyperParam):
self._readable_name, self.var_name = get_op_var_name(name)
def setup_graph(self):
try:
all_vars = tf.global_variables()
except:
# TODO
all_vars = tf.all_variables()
all_vars = tf.global_variables()
for v in all_vars:
if v.name == self.var_name:
self.var = v
......
......@@ -10,7 +10,7 @@ import unittest
class TestModel(unittest.TestCase):
def run_variable(self, var):
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
if isinstance(var, list):
return sess.run(var)
else:
......
......@@ -151,7 +151,7 @@ if __name__ == '__main__':
mapv = tf.Variable(mapping)
output = ImageSample('sample', [imv, mapv], borderMode='constant')
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
#out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(output)
......
......@@ -46,7 +46,7 @@ class NewSession(SessionInit):
initializer.
"""
def _init(self, sess):
sess.run(tf.initialize_all_variables())
sess.run(tf.global_variables_initializer())
class SaverRestore(SessionInit):
"""
......@@ -85,10 +85,7 @@ class SaverRestore(SessionInit):
for dic in SaverRestore._produce_restore_dict(vars_map):
# multiple saver under same name scope would cause error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
try:
saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
except:
saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
saver.restore(sess, self.path)
def set_path(self, model_path):
......@@ -124,10 +121,7 @@ class SaverRestore(SessionInit):
:param vars_available: varaible names available in the checkpoint, for existence checking
:returns: a dict of {var_name: [var, var]} to restore
"""
try:
vars_to_restore = tf.global_variables()
except AttributeError:
vars_to_restore = tf.all_variables()
vars_to_restore = tf.global_variables()
var_dict = defaultdict(list)
chkpt_vars_used = set()
for v in vars_to_restore:
......
......@@ -119,10 +119,7 @@ class Trainer(object):
logger.info("Initializing graph variables ...")
# TODO newsession + sessinit?
try:
initop = tf.global_variables_initializer()
except:
initop = tf.initialize_all_variables()
initop = tf.global_variables_initializer()
self.sess.run(initop)
self.config.session_init.init(self.sess)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment