Commit e68eec29 authored by Yuxin Wu

better warning about user-provided sessconfig

parent 6a0bba68
@@ -75,24 +75,24 @@ if __name__ == '__main__':
     if os.path.isdir(args.input):
         input, meta = guess_inputs(args.input)
     else:
         assert args.meta is not None
         meta = args.meta
         input = args.input

     # this script does not need GPU
     os.environ['CUDA_VISIBLE_DEVICES'] = ''

-    while True:
-        try:
-            tf.reset_default_graph()
-            tf.train.import_meta_graph(meta, clear_devices=True)
-        except KeyError as e:
-            op_name = e.args[0]
-            _import_external_ops(op_name)
-        except tf.errors.NotFoundError as e:
-            _import_external_ops(e.message)
-        else:
-            break
+    if args.meta is not None:
+        while True:
+            try:
+                tf.reset_default_graph()
+                tf.train.import_meta_graph(meta, clear_devices=True)
+            except KeyError as e:
+                op_name = e.args[0]
+                _import_external_ops(op_name)
+            except tf.errors.NotFoundError as e:
+                _import_external_ops(e.message)
+            else:
+                break

     # loading...
     if input.endswith('.npz'):
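
The retry loop above re-attempts the meta-graph import whenever an op cannot be resolved, letting a helper register the missing op first. The real `_import_external_ops` is defined elsewhere in this script and is not part of this diff; the sketch below is only an assumption of what such a helper could look like, using TF 1.x's `tf.load_op_library` and a purely hypothetical op-to-library mapping.

```python
import tensorflow as tf

# Hypothetical sketch in the spirit of _import_external_ops (the real helper
# lives elsewhere in this script): given the name of an op that
# import_meta_graph failed to resolve, load a custom-op library that is
# assumed to register it, so the import can be retried.
_OP_TO_LIB = {
    # illustrative placeholder mapping; real op/library names will differ
    "MyCustomOp": "libmy_custom_op.so",
}

def import_external_op_sketch(op_name):
    lib = _OP_TO_LIB.get(op_name)
    if lib is None:
        raise RuntimeError("No known library provides op: {}".format(op_name))
    tf.load_op_library(lib)  # registers the op kernels with the TF runtime
```
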
@@ -101,17 +101,20 @@ if __name__ == '__main__':
         dic = varmanip.load_chkpt_vars(input)
         dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}

-    # save variables that are GLOBAL, and either TRAINABLE or MODEL
-    var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
-    if len(set(var_to_dump)) != len(var_to_dump):
-        logger.warn("TRAINABLE and MODEL variables have duplication!")
-    var_to_dump = list(set(var_to_dump))
-    globvarname = set([k.name for k in tf.global_variables()])
-    var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
-
-    for name in var_to_dump:
-        assert name in dic, "Variable {} not found in the model!".format(name)
+    if args.meta is not None:
+        # save variables that are GLOBAL, and either TRAINABLE or MODEL
+        var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+        var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
+        if len(set(var_to_dump)) != len(var_to_dump):
+            logger.warn("TRAINABLE and MODEL variables have duplication!")
+        var_to_dump = list(set(var_to_dump))
+        globvarname = set([k.name for k in tf.global_variables()])
+        var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
+
+        for name in var_to_dump:
+            assert name in dic, "Variable {} not found in the model!".format(name)
+    else:
+        var_to_dump = set(dic.keys())

     dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
     varmanip.save_chkpt_vars(dic_to_dump, args.output)
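
With this change the script can also run without a meta graph: when `args.meta` is None, every variable found in the checkpoint is dumped instead of filtering by the graph's TRAINABLE/MODEL collections. Below is a minimal standalone sketch of that new code path; the import paths are assumptions on my part (only the use of `varmanip` and `get_op_tensor_name` is shown above), and the helpers are assumed to behave as they are used in the diff.

```python
import six
from tensorpack.tfutils import varmanip
from tensorpack.tfutils.common import get_op_tensor_name

def dump_all_checkpoint_vars(checkpoint, output):
    """Sketch of the no-meta-graph path: dump every variable in the checkpoint."""
    dic = varmanip.load_chkpt_vars(checkpoint)
    # normalize keys to tensor names, as the script does above
    dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}
    varmanip.save_chkpt_vars(dic, output)  # output format inferred from the file name
```
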
@@ -17,6 +17,21 @@ A SessionCreator should:
 """

+_WRN1 = """User-provided custom session config may not work due to TF bugs. If you saw logs like
+```
+tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 0 with properties:
+```
+before this line, then your GPU has been initialized and custom GPU options may not take effect. """
+
+_WRN2 = """To work around this issue, you can do one of the following:
+1. Avoid initializing the GPU too early. Find code that initializes the GPU and skip it.
+   Typical examples are: creating a session; checking GPU availability; checking the number of GPUs.
+2. Manually set your GPU options earlier. You can create a session with custom
+   GPU options at the beginning of your program, as described in
+   https://github.com/tensorpack/tensorpack/issues/497
+"""
+

 class NewSessionCreator(tf.train.SessionCreator):
     def __init__(self, target='', config=None):
         """
@@ -33,9 +48,8 @@ class NewSessionCreator(tf.train.SessionCreator):
             config = get_default_sess_config()
         else:
             self.user_provided_config = True
-            logger.warn(
-                "User-provided custom session config may not work due to TF \
-bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
+            logger.warn(_WRN1)
+            logger.warn(_WRN2)
         self.config = config

     def create_session(self):
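
For reference, here is a minimal sketch of workaround 2 from `_WRN2`, using the TF 1.x API: create a session with the desired GPU options at the very start of the program, before anything else can initialize the GPU with default settings. The specific options shown (`allow_growth`, `visible_device_list`) are only examples of settings one might want to pin early; see the linked issue 497 for the underlying behavior.

```python
import tensorflow as tf

# Run this before anything else touches the GPU (building models, querying
# device availability, counting GPUs), so the GPU is initialized with the
# options you actually want.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True        # don't reserve all GPU memory up front
config.gpu_options.visible_device_list = '0'  # example: restrict TF to GPU 0
sess = tf.Session(config=config)
sess.close()  # per the warning above, the first session's GPU options are the ones that take effect
```
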