Commit 35f24d40 authored by Yuxin Wu's avatar Yuxin Wu

varmanip

parent f8d4352a
#!/usr/bin/env python2 #!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: dump_model_params.py # File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import argparse import argparse
import cv2
import tensorflow as tf import tensorflow as tf
import imp import imp
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.tfutils import sessinit from tensorpack.tfutils import sessinit, varmanip
from tensorpack.dataflow import * from tensorpack.dataflow import *
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -27,4 +26,4 @@ with tf.Graph().as_default() as G: ...@@ -27,4 +26,4 @@ with tf.Graph().as_default() as G:
sess = tf.Session() sess = tf.Session()
init.init(sess) init.init(sess)
with sess.as_default(): with sess.as_default():
sessinit.dump_session_params(args.output) varmanip.dump_session_params(args.output)
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import os import os
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import numpy as np
from collections import defaultdict from collections import defaultdict
import re import re
import tensorflow as tf import tensorflow as tf
...@@ -12,12 +11,11 @@ import six ...@@ -12,12 +11,11 @@ import six
from ..utils import logger, EXTRA_SAVE_VARS_KEY from ..utils import logger, EXTRA_SAVE_VARS_KEY
from .common import get_op_var_name from .common import get_op_var_name
from .sessupdate import SessionUpdate from .varmanip import SessionUpdate
__all__ = ['SessionInit', 'NewSession', 'SaverRestore', __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore', 'ChainInit', 'ParamRestore', 'ChainInit',
'JustCurrentSession', 'JustCurrentSession']
'dump_session_params']
# TODO they initialize_all at the beginning by default. # TODO they initialize_all at the beginning by default.
...@@ -180,17 +178,3 @@ def ChainInit(SessionInit): ...@@ -180,17 +178,3 @@ def ChainInit(SessionInit):
def _init(self, sess): def _init(self, sess):
for i in self.inits: for i in self.inits:
i.init(sess) i.init(sess)
def dump_session_params(path):
""" Dump value of all trainable variables to a dict and save to `path` as
npy format, loadable by ParamRestore
"""
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
result = {}
for v in var:
name = v.name.replace(":0", "")
result[name] = v.eval()
logger.info("Variables to save to {}:".format(path))
logger.info(str(result.keys()))
np.save(path, result)
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: sessupdate.py # File: varmanip.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six import six
import tensorflow as tf import tensorflow as tf
import numpy as np
__all__ = ['SessionUpdate'] __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars']
class SessionUpdate(object): class SessionUpdate(object):
""" Update the variables in a session """ """ Update the variables in a session """
...@@ -35,3 +36,28 @@ class SessionUpdate(object): ...@@ -35,3 +36,28 @@ class SessionUpdate(object):
logger.warn("Param {} is reshaped during assigning".format(name)) logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape) value = value.reshape(varshape)
self.sess.run(op, feed_dict={p: value}) self.sess.run(op, feed_dict={p: value})
def dump_session_params(path):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
npy format, loadable by ParamRestore
"""
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
result = {}
for v in var:
name = v.name.replace(":0", "")
result[name] = v.eval()
logger.info("Variables to save to {}:".format(path))
logger.info(str(result.keys()))
np.save(path, result)
def dump_chkpt_vars(model_path, output):
""" Dump all variables from a checkpoint """
reader = tf.train.NewCheckpointReader(model_path)
var_names = reader.get_variable_to_shape_map().keys()
result = {}
for n in var_names:
result[n] = reader.get_tensor(n)
logger.info("Variables to save to {}:".format(output))
logger.info(str(result.keys()))
np.save(output, result)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment