Commit 65449110 authored by Yuxin Wu's avatar Yuxin Wu

smarter dump-model-params.py

parent 891dc488
...@@ -9,55 +9,104 @@ import os ...@@ -9,55 +9,104 @@ import os
import six import six
import tensorflow as tf import tensorflow as tf
from tensorpack import logger
from tensorpack.tfutils import varmanip from tensorpack.tfutils import varmanip
from tensorpack.tfutils.common import get_op_tensor_name from tensorpack.tfutils.common import get_op_tensor_name, get_tf_version_tuple
TF_version = get_tf_version_tuple()
def _import_external_ops(op_name):
if "horovod" in op_name.lower(): def _import_external_ops(message):
if "horovod" in message.lower():
logger.info("Importing horovod ...")
import horovod.tensorflow # noqa import horovod.tensorflow # noqa
return return
if op_name == "MaxBytesInUse": if "MaxBytesInUse" in message:
logger.info("Importing memory_stats ...")
from tensorflow.contrib.memory_stats import MaxBytesInUse # noqa from tensorflow.contrib.memory_stats import MaxBytesInUse # noqa
return return
print("Your graph contains op '{}' which is not loaded into your Tensorflow runtime.".format(op_name)) if 'Nccl' in message:
print("Therefore the graph cannot be loaded unless you import the relevant libraries first.") logger.info("Importing nccl ...")
sys.exit(1) if TF_version <= (1, 12):
try:
from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so
except Exception:
pass
else:
_validate_and_load_nccl_so()
from tensorflow.contrib.nccl.ops import gen_nccl_ops
else:
from tensorflow.python.ops import gen_nccl_ops
return
def guess_inputs(input_dir):
    """Pick the latest checkpoint and metagraph file from a train_log directory.

    Scans *input_dir* for ``graph-*.meta`` and ``model-*.index`` files,
    choosing the lexicographically-last graph file and the model file with
    the highest iteration number.

    Returns:
        (model_path, meta_path): full paths to the chosen files.
    """
    metas = []
    models = []  # list of (filename, iteration number)
    for fname in os.listdir(input_dir):
        if fname.startswith('graph-') and fname.endswith('.meta'):
            metas.append(fname)
        if fname.startswith('model-') and fname.endswith('.index'):
            step = int(fname[len('model-'):-len('.index')])
            models.append((fname, step))

    assert len(metas)
    # Graph filenames embed a timestamp, so the lexicographic max is the newest.
    meta = max(metas)
    if len(metas) > 1:
        logger.info("Choosing {} from {} as graph file.".format(meta, metas))
    else:
        logger.info("Choosing {} as graph file.".format(meta))

    assert len(models)
    # Iteration numbers are unique per filename, so max() is unambiguous.
    model = max(models, key=lambda pair: pair[1])[0]
    if len(models) > 1:
        logger.info("Choosing {} from {} as model file.".format(model, [p[0] for p in models]))
    else:
        logger.info("Choosing {} as model file.".format(model))

    return os.path.join(input_dir, model), os.path.join(input_dir, meta)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Keep only TRAINABLE and MODEL variables in a checkpoint.') description='Keep only TRAINABLE and MODEL variables in a checkpoint.')
parser.add_argument('--meta', help='metagraph file', required=True) parser.add_argument('--meta', help='metagraph file')
parser.add_argument(dest='input', help='input model file, has to be a TF checkpoint') parser.add_argument(dest='input', help='input model file, has to be a TF checkpoint')
parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint') parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
args = parser.parse_args() args = parser.parse_args()
if os.path.isdir(args.input):
input, meta = guess_inputs(args.input)
else:
assert args.meta is not None
meta = args.meta
input = args.input
# this script does not need GPU # this script does not need GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '' os.environ['CUDA_VISIBLE_DEVICES'] = ''
while True: while True:
try: try:
tf.reset_default_graph() tf.reset_default_graph()
tf.train.import_meta_graph(args.meta, clear_devices=True) tf.train.import_meta_graph(meta, clear_devices=True)
except KeyError as e: except KeyError as e:
op_name = e.args[0] op_name = e.args[0]
_import_external_ops(op_name) _import_external_ops(op_name)
except tf.errors.NotFoundError as e:
_import_external_ops(e.message)
else: else:
break break
# loading... # loading...
if args.input.endswith('.npz'): if input.endswith('.npz'):
dic = np.load(args.input) dic = np.load(input)
else: else:
dic = varmanip.load_chkpt_vars(args.input) dic = varmanip.load_chkpt_vars(input)
dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)} dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}
# save variables that are GLOBAL, and either TRAINABLE or MODEL # save variables that are GLOBAL, and either TRAINABLE or MODEL
var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
if len(set(var_to_dump)) != len(var_to_dump): if len(set(var_to_dump)) != len(var_to_dump):
print("TRAINABLE and MODEL variables have duplication!") logger.warn("TRAINABLE and MODEL variables have duplication!")
var_to_dump = list(set(var_to_dump)) var_to_dump = list(set(var_to_dump))
globvarname = set([k.name for k in tf.global_variables()]) globvarname = set([k.name for k in tf.global_variables()])
var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname]) var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment