Commit 8b4d4f77 authored by Yuxin Wu's avatar Yuxin Wu

dump-model without GPU. some checks on windows support

parent c8a9e4e5
......@@ -3,11 +3,14 @@
# File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import six
import argparse
import os
import tensorflow as tf
from tensorpack import logger
from tensorpack.tfutils import varmanip, get_model_loader
from tensorpack.tfutils import varmanip
from tensorpack.tfutils.common import get_op_tensor_name
if __name__ == '__main__':
parser = argparse.ArgumentParser(
......@@ -17,31 +20,27 @@ if __name__ == '__main__':
parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
args = parser.parse_args()
# this script does not need GPU
os.environ['CUDA_VISIBLE_DEVICES'] = ''
tf.train.import_meta_graph(args.meta, clear_devices=True)
# loading...
init = get_model_loader(args.input)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
init.init(sess)
# dump ...
with sess.as_default():
if args.output.endswith('npy') or args.output.endswith('npz'):
varmanip.dump_session_params(args.output)
if args.input.endswith('.npz'):
dic = np.load(args.input)
else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
gvars = set([k.name for k in tf.global_variables()])
var = [v for v in var if v.name in gvars]
var_dict = {}
for v in var:
name = varmanip.get_savename_from_varname(v.name)
var_dict[name] = v
logger.info("Variables to dump:")
logger.info(", ".join(var_dict.keys()))
saver = tf.train.Saver(
var_list=var_dict,
write_version=tf.train.SaverDef.V2)
saver.save(sess, args.output, write_meta_graph=False)
dic = varmanip.load_chkpt_vars(args.input)
dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}
# save variables that are GLOBAL, and either TRAINABLE or MODEL
var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
assert len(set(var_to_dump)) == len(var_to_dump), "TRAINABLE and MODEL variables have duplication!"
globvarname = [k.name for k in tf.global_variables()]
var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
for name in var_to_dump:
assert name in dic, "Variable {} not found in the model!".format(name)
dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
varmanip.save_chkpt_vars(dic_to_dump, args.output)
......@@ -35,6 +35,7 @@ class GPUUtilizationTracker(Callback):
Args:
devices (list[int]): physical GPU ids. If None, will use CUDA_VISIBLE_DEVICES
"""
assert os.name != 'nt', "GPUUtilizationTracker does not support windows!"
if devices is None:
env = os.environ.get('CUDA_VISIBLE_DEVICES')
if env is None:
......
......@@ -166,6 +166,8 @@ class MultiProcessPrefetchData(ProxyDataFlow):
nr_prefetch (int): size of the queue to hold prefetched datapoints.
nr_proc (int): number of processes to use.
"""
if os.name == 'nt':
logger.warn("MultiProcessPrefetchData may not support windows!")
super(MultiProcessPrefetchData, self).__init__(ds)
try:
self._size = ds.size()
......
......@@ -117,6 +117,7 @@ def dump_session_params(path):
Args:
path(str): the file name to save the parameters. Must ends with npz.
"""
# save variables that are GLOBAL, and either TRAINABLE or MODEL
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
# TODO dedup
......@@ -126,15 +127,33 @@ def dump_session_params(path):
result = {}
for v in var:
result[v.name] = v.eval()
save_chkpt_vars(result, path)
def save_chkpt_vars(dic, path):
"""
Save variables in dic to path.
Args:
dic: {name: value}
path: save as npz if the name ends with '.npz', otherwise save as a checkpoint.
"""
logger.info("Variables to save to {}:".format(path))
keys = sorted(list(result.keys()))
keys = sorted(list(dic.keys()))
logger.info(pprint.pformat(keys))
if path.endswith('.npy'):
np.save(path, result)
elif path.endswith('.npz'):
np.savez_compressed(path, **result)
assert not path.endswith('.npy')
if path.endswith('.npz'):
np.savez_compressed(path, **dic)
else:
raise ValueError("Don't know which format to use for {}".format(path))
with tf.Graph().as_default(), \
tf.Session() as sess:
for k, v in six.iteritems(dic):
k = get_op_tensor_name(k)[0]
_ = tf.Variable(name=k, initial_value=v) # noqa
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, path, write_meta_graph=False)
def get_checkpoint_path(model_path):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment