Commit 35f24d40 authored by Yuxin Wu's avatar Yuxin Wu

varmanip

parent f8d4352a
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dump_model_params.py
# File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import argparse
import cv2
import tensorflow as tf
import imp
from tensorpack.utils import *
from tensorpack.tfutils import sessinit
from tensorpack.tfutils import sessinit, varmanip
from tensorpack.dataflow import *
parser = argparse.ArgumentParser()
......@@ -27,4 +26,4 @@ with tf.Graph().as_default() as G:
sess = tf.Session()
init.init(sess)
with sess.as_default():
sessinit.dump_session_params(args.output)
varmanip.dump_session_params(args.output)
......@@ -4,7 +4,6 @@
import os
from abc import abstractmethod, ABCMeta
import numpy as np
from collections import defaultdict
import re
import tensorflow as tf
......@@ -12,12 +11,11 @@ import six
from ..utils import logger, EXTRA_SAVE_VARS_KEY
from .common import get_op_var_name
from .sessupdate import SessionUpdate
from .varmanip import SessionUpdate
__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
'ParamRestore', 'ChainInit',
'JustCurrentSession',
'dump_session_params']
'JustCurrentSession']
# TODO they initialize_all at the beginning by default.
......@@ -180,17 +178,3 @@ def ChainInit(SessionInit):
def _init(self, sess):
for i in self.inits:
i.init(sess)
def dump_session_params(path):
""" Dump value of all trainable variables to a dict and save to `path` as
npy format, loadable by ParamRestore
"""
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
result = {}
for v in var:
name = v.name.replace(":0", "")
result[name] = v.eval()
logger.info("Variables to save to {}:".format(path))
logger.info(str(result.keys()))
np.save(path, result)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: sessupdate.py
# File: varmanip.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six
import tensorflow as tf
import numpy as np
__all__ = ['SessionUpdate']
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars']
class SessionUpdate(object):
""" Update the variables in a session """
......@@ -35,3 +36,28 @@ class SessionUpdate(object):
logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape)
self.sess.run(op, feed_dict={p: value})
def dump_session_params(path):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
npy format, loadable by ParamRestore
"""
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
result = {}
for v in var:
name = v.name.replace(":0", "")
result[name] = v.eval()
logger.info("Variables to save to {}:".format(path))
logger.info(str(result.keys()))
np.save(path, result)
def dump_chkpt_vars(model_path, output):
""" Dump all variables from a checkpoint """
reader = tf.train.NewCheckpointReader(model_path)
var_names = reader.get_variable_to_shape_map().keys()
result = {}
for n in var_names:
result[n] = reader.get_tensor(n)
logger.info("Variables to save to {}:".format(output))
logger.info(str(result.keys()))
np.save(output, result)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment