Commit 8b4d4f77 authored by Yuxin Wu's avatar Yuxin Wu

dump-model without GPU. some checks on windows support

parent c8a9e4e5
...@@ -3,11 +3,14 @@ ...@@ -3,11 +3,14 @@
# File: dump-model-params.py # File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import six
import argparse import argparse
import os
import tensorflow as tf import tensorflow as tf
from tensorpack import logger from tensorpack.tfutils import varmanip
from tensorpack.tfutils import varmanip, get_model_loader from tensorpack.tfutils.common import get_op_tensor_name
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -17,31 +20,27 @@ if __name__ == '__main__': ...@@ -17,31 +20,27 @@ if __name__ == '__main__':
parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint') parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
args = parser.parse_args() args = parser.parse_args()
# this script does not need GPU
os.environ['CUDA_VISIBLE_DEVICES'] = ''
tf.train.import_meta_graph(args.meta, clear_devices=True) tf.train.import_meta_graph(args.meta, clear_devices=True)
# loading... # loading...
init = get_model_loader(args.input) if args.input.endswith('.npz'):
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) dic = np.load(args.input)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
init.init(sess)
# dump ...
with sess.as_default():
if args.output.endswith('npy') or args.output.endswith('npz'):
varmanip.dump_session_params(args.output)
else: else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) dic = varmanip.load_chkpt_vars(args.input)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}
gvars = set([k.name for k in tf.global_variables()])
var = [v for v in var if v.name in gvars] # save variables that are GLOBAL, and either TRAINABLE or MODEL
var_dict = {} var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for v in var: var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
name = varmanip.get_savename_from_varname(v.name) assert len(set(var_to_dump)) == len(var_to_dump), "TRAINABLE and MODEL variables have duplication!"
var_dict[name] = v globvarname = [k.name for k in tf.global_variables()]
logger.info("Variables to dump:") var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
logger.info(", ".join(var_dict.keys()))
saver = tf.train.Saver( for name in var_to_dump:
var_list=var_dict, assert name in dic, "Variable {} not found in the model!".format(name)
write_version=tf.train.SaverDef.V2)
saver.save(sess, args.output, write_meta_graph=False) dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
varmanip.save_chkpt_vars(dic_to_dump, args.output)
...@@ -35,6 +35,7 @@ class GPUUtilizationTracker(Callback): ...@@ -35,6 +35,7 @@ class GPUUtilizationTracker(Callback):
Args: Args:
devices (list[int]): physical GPU ids. If None, will use CUDA_VISIBLE_DEVICES devices (list[int]): physical GPU ids. If None, will use CUDA_VISIBLE_DEVICES
""" """
assert os.name != 'nt', "GPUUtilizationTracker does not support windows!"
if devices is None: if devices is None:
env = os.environ.get('CUDA_VISIBLE_DEVICES') env = os.environ.get('CUDA_VISIBLE_DEVICES')
if env is None: if env is None:
......
...@@ -166,6 +166,8 @@ class MultiProcessPrefetchData(ProxyDataFlow): ...@@ -166,6 +166,8 @@ class MultiProcessPrefetchData(ProxyDataFlow):
nr_prefetch (int): size of the queue to hold prefetched datapoints. nr_prefetch (int): size of the queue to hold prefetched datapoints.
nr_proc (int): number of processes to use. nr_proc (int): number of processes to use.
""" """
if os.name == 'nt':
logger.warn("MultiProcessPrefetchData may not support windows!")
super(MultiProcessPrefetchData, self).__init__(ds) super(MultiProcessPrefetchData, self).__init__(ds)
try: try:
self._size = ds.size() self._size = ds.size()
......
...@@ -117,6 +117,7 @@ def dump_session_params(path): ...@@ -117,6 +117,7 @@ def dump_session_params(path):
Args: Args:
path(str): the file name to save the parameters. Must ends with npz. path(str): the file name to save the parameters. Must ends with npz.
""" """
# save variables that are GLOBAL, and either TRAINABLE or MODEL
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
# TODO dedup # TODO dedup
...@@ -126,15 +127,33 @@ def dump_session_params(path): ...@@ -126,15 +127,33 @@ def dump_session_params(path):
result = {} result = {}
for v in var: for v in var:
result[v.name] = v.eval() result[v.name] = v.eval()
save_chkpt_vars(result, path)
def save_chkpt_vars(dic, path):
"""
Save variables in dic to path.
Args:
dic: {name: value}
path: save as npz if the name ends with '.npz', otherwise save as a checkpoint.
"""
logger.info("Variables to save to {}:".format(path)) logger.info("Variables to save to {}:".format(path))
keys = sorted(list(result.keys())) keys = sorted(list(dic.keys()))
logger.info(pprint.pformat(keys)) logger.info(pprint.pformat(keys))
if path.endswith('.npy'):
np.save(path, result) assert not path.endswith('.npy')
elif path.endswith('.npz'): if path.endswith('.npz'):
np.savez_compressed(path, **result) np.savez_compressed(path, **dic)
else: else:
raise ValueError("Don't know which format to use for {}".format(path)) with tf.Graph().as_default(), \
tf.Session() as sess:
for k, v in six.iteritems(dic):
k = get_op_tensor_name(k)[0]
_ = tf.Variable(name=k, initial_value=v) # noqa
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, path, write_meta_graph=False)
def get_checkpoint_path(model_path): def get_checkpoint_path(model_path):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment