Commit 4e7977f5 authored by Yuxin Wu's avatar Yuxin Wu

remove all_variables

parent f1fdb42e
...@@ -43,12 +43,7 @@ class GraphVarParam(HyperParam): ...@@ -43,12 +43,7 @@ class GraphVarParam(HyperParam):
self._readable_name, self.var_name = get_op_var_name(name) self._readable_name, self.var_name = get_op_var_name(name)
def setup_graph(self): def setup_graph(self):
try: all_vars = tf.global_variables()
all_vars = tf.global_variables()
except:
# TODO
all_vars = tf.all_variables()
for v in all_vars: for v in all_vars:
if v.name == self.var_name: if v.name == self.var_name:
self.var = v self.var = v
......
...@@ -10,7 +10,7 @@ import unittest ...@@ -10,7 +10,7 @@ import unittest
class TestModel(unittest.TestCase): class TestModel(unittest.TestCase):
def run_variable(self, var): def run_variable(self, var):
sess = tf.Session() sess = tf.Session()
sess.run(tf.initialize_all_variables()) sess.run(tf.global_variables_initializer())
if isinstance(var, list): if isinstance(var, list):
return sess.run(var) return sess.run(var)
else: else:
......
...@@ -151,7 +151,7 @@ if __name__ == '__main__': ...@@ -151,7 +151,7 @@ if __name__ == '__main__':
mapv = tf.Variable(mapping) mapv = tf.Variable(mapping)
output = ImageSample('sample', [imv, mapv], borderMode='constant') output = ImageSample('sample', [imv, mapv], borderMode='constant')
sess = tf.Session() sess = tf.Session()
sess.run(tf.initialize_all_variables()) sess.run(tf.global_variables_initializer())
#out = sess.run(tf.gradients(tf.reduce_sum(output), mapv)) #out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(output) #out = sess.run(output)
......
...@@ -46,7 +46,7 @@ class NewSession(SessionInit): ...@@ -46,7 +46,7 @@ class NewSession(SessionInit):
initializer. initializer.
""" """
def _init(self, sess): def _init(self, sess):
sess.run(tf.initialize_all_variables()) sess.run(tf.global_variables_initializer())
class SaverRestore(SessionInit): class SaverRestore(SessionInit):
""" """
...@@ -85,10 +85,7 @@ class SaverRestore(SessionInit): ...@@ -85,10 +85,7 @@ class SaverRestore(SessionInit):
for dic in SaverRestore._produce_restore_dict(vars_map): for dic in SaverRestore._produce_restore_dict(vars_map):
# multiple saver under same name scope would cause error: # multiple saver under same name scope would cause error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name # training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
try: saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
except:
saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
saver.restore(sess, self.path) saver.restore(sess, self.path)
def set_path(self, model_path): def set_path(self, model_path):
...@@ -124,10 +121,7 @@ class SaverRestore(SessionInit): ...@@ -124,10 +121,7 @@ class SaverRestore(SessionInit):
:param vars_available: varaible names available in the checkpoint, for existence checking :param vars_available: varaible names available in the checkpoint, for existence checking
:returns: a dict of {var_name: [var, var]} to restore :returns: a dict of {var_name: [var, var]} to restore
""" """
try: vars_to_restore = tf.global_variables()
vars_to_restore = tf.global_variables()
except AttributeError:
vars_to_restore = tf.all_variables()
var_dict = defaultdict(list) var_dict = defaultdict(list)
chkpt_vars_used = set() chkpt_vars_used = set()
for v in vars_to_restore: for v in vars_to_restore:
......
...@@ -119,10 +119,7 @@ class Trainer(object): ...@@ -119,10 +119,7 @@ class Trainer(object):
logger.info("Initializing graph variables ...") logger.info("Initializing graph variables ...")
# TODO newsession + sessinit? # TODO newsession + sessinit?
try: initop = tf.global_variables_initializer()
initop = tf.global_variables_initializer()
except:
initop = tf.initialize_all_variables()
self.sess.run(initop) self.sess.run(initop)
self.config.session_init.init(self.sess) self.config.session_init.init(self.sess)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment