Commit 4942ef45 authored by Yuxin Wu's avatar Yuxin Wu

Improve both SessionInit implementations (SaverRestore and ParamRestore)

parent 3b2f7df1
......@@ -121,7 +121,7 @@ class InferenceRunner(Callback):
class ScalarStats(Inferencer):
"""
Write stat and summary of some scalar tensor.
Write some scalar tensor to both stat and summary.
The output of the given Ops must be a scalar.
The value will be averaged over all data points in the dataset.
"""
......
......@@ -42,7 +42,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
assert n_out is not None
beta = tf.get_variable('beta', [n_out])
gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0))
initializer=tf.ones_initializer)
if len(shape) == 2:
batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
......
......@@ -96,4 +96,3 @@ class ModelFromMetaGraph(ModelDesc):
def _build_graph(self, _, __):
    """Intentionally a no-op: the graph was already imported from the
    meta-graph file (see ModelFromMetaGraph), so there is nothing to build.
    The two positional arguments (inputs, is_training in the usual
    _build_graph contract — presumably; confirm against ModelDesc) are
    deliberately ignored.
    """
    pass
......@@ -6,6 +6,7 @@ import tensorflow as tf
from collections import namedtuple
from six.moves import zip
from tensorpack.models import ModelDesc
from ..tfutils import *
import multiprocessing
......@@ -53,8 +54,10 @@ class PredictConfig(object):
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
self.session_config = kwargs.pop('session_config', get_default_sess_config(0.3))
self.session_init = kwargs.pop('session_init')
self.session_init = kwargs.pop('session_init', JustCurrentSession())
assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
self.input_data_mapping = kwargs.pop('input_data_mapping', None)
self.output_var_names = kwargs.pop('output_var_names')
self.return_input = kwargs.pop('return_input', False)
......@@ -86,4 +89,5 @@ def get_predict_func(config):
def run_input(dp):
    # Map each component of the datapoint `dp` onto the corresponding
    # input placeholder (`input_map`, `sess`, `output_vars` are closed over
    # from the enclosing get_predict_func scope — not visible in this chunk).
    feed = dict(zip(input_map, dp))
    # Run one forward pass and return the fetched output values.
    return sess.run(output_vars, feed_dict=feed)
# Expose the session on the returned callable so callers can reuse/close it.
run_input.session = sess
return run_input
......@@ -68,7 +68,9 @@ class SaverRestore(SessionInit):
chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars)
for dic in SaverRestore._produce_restore_dict(vars_map):
saver = tf.train.Saver(var_list=dic)
# multiple saver under same name scope would cause error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
saver.restore(sess, self.path)
def set_path(self, model_path):
......@@ -148,7 +150,10 @@ class ParamRestore(SessionInit):
"{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during loading!".format(name))
value = value.reshape(varshape)
sess.run(var.assign(value))
# assign(value) creates ops with values being saved, doubling the size of metagraph
# assign(placeholder) works better here
p = tf.placeholder(value.dtype, shape=value.shape)
sess.run(var.assign(p), feed_dict={p:value})
def ChainInit(SessionInit):
""" Init a session by a list of SessionInit instance."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment