Commit e68eec29 authored by Yuxin Wu's avatar Yuxin Wu

better warning about user-provided sessconfig

parent 6a0bba68
...@@ -75,13 +75,13 @@ if __name__ == '__main__': ...@@ -75,13 +75,13 @@ if __name__ == '__main__':
if os.path.isdir(args.input): if os.path.isdir(args.input):
input, meta = guess_inputs(args.input) input, meta = guess_inputs(args.input)
else: else:
assert args.meta is not None
meta = args.meta meta = args.meta
input = args.input input = args.input
# this script does not need GPU # this script does not need GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '' os.environ['CUDA_VISIBLE_DEVICES'] = ''
if args.meta is not None:
while True: while True:
try: try:
tf.reset_default_graph() tf.reset_default_graph()
...@@ -101,6 +101,7 @@ if __name__ == '__main__': ...@@ -101,6 +101,7 @@ if __name__ == '__main__':
dic = varmanip.load_chkpt_vars(input) dic = varmanip.load_chkpt_vars(input)
dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)} dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}
if args.meta is not None:
# save variables that are GLOBAL, and either TRAINABLE or MODEL # save variables that are GLOBAL, and either TRAINABLE or MODEL
var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
...@@ -112,6 +113,8 @@ if __name__ == '__main__': ...@@ -112,6 +113,8 @@ if __name__ == '__main__':
for name in var_to_dump: for name in var_to_dump:
assert name in dic, "Variable {} not found in the model!".format(name) assert name in dic, "Variable {} not found in the model!".format(name)
else:
var_to_dump = set(dic.keys())
dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump} dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
varmanip.save_chkpt_vars(dic_to_dump, args.output) varmanip.save_chkpt_vars(dic_to_dump, args.output)
...@@ -17,6 +17,21 @@ A SessionCreator should: ...@@ -17,6 +17,21 @@ A SessionCreator should:
""" """
_WRN1 = """User-provided custom session config may not work due to TF bugs. If you saw logs like
```
tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 0 with properties:
```
before this line, then your GPU has been initialized and custom GPU options may not take effect. """
_WRN2 = """To workaround this issue, you can do one of the following:
1. Avoid initializing the GPU too early. Find code that initializes the GPU and skip it.
Typically examples are: creating a session; check GPU availability; check GPU number.
2. Manually set your GPU options earlier. You can create a session with custom
GPU options at the beginning of your program, as described in
https://github.com/tensorpack/tensorpack/issues/497
"""
class NewSessionCreator(tf.train.SessionCreator): class NewSessionCreator(tf.train.SessionCreator):
def __init__(self, target='', config=None): def __init__(self, target='', config=None):
""" """
...@@ -33,9 +48,8 @@ class NewSessionCreator(tf.train.SessionCreator): ...@@ -33,9 +48,8 @@ class NewSessionCreator(tf.train.SessionCreator):
config = get_default_sess_config() config = get_default_sess_config()
else: else:
self.user_provided_config = True self.user_provided_config = True
logger.warn( logger.warn(_WRN1)
"User-provided custom session config may not work due to TF \ logger.warn(_WRN2)
bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
self.config = config self.config = config
def create_session(self): def create_session(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment