Commit 4942ef45 authored by Yuxin Wu's avatar Yuxin Wu

Improve both SessionInit implementations: give SaverRestore's per-dict Savers unique names to avoid a restore_all assertion, make ParamRestore assign via placeholders to avoid doubling the metagraph size, and let PredictConfig default session_init to JustCurrentSession with type assertions.

parent 3b2f7df1
...@@ -121,7 +121,7 @@ class InferenceRunner(Callback): ...@@ -121,7 +121,7 @@ class InferenceRunner(Callback):
class ScalarStats(Inferencer): class ScalarStats(Inferencer):
""" """
Write stat and summary of some scalar tensor. Write some scalar tensor to both stat and summary.
The output of the given Ops must be a scalar. The output of the given Ops must be a scalar.
The value will be averaged over all data points in the dataset. The value will be averaged over all data points in the dataset.
""" """
......
...@@ -42,7 +42,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5): ...@@ -42,7 +42,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
assert n_out is not None assert n_out is not None
beta = tf.get_variable('beta', [n_out]) beta = tf.get_variable('beta', [n_out])
gamma = tf.get_variable('gamma', [n_out], gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0)) initializer=tf.ones_initializer)
if len(shape) == 2: if len(shape) == 2:
batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False) batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
......
...@@ -96,4 +96,3 @@ class ModelFromMetaGraph(ModelDesc): ...@@ -96,4 +96,3 @@ class ModelFromMetaGraph(ModelDesc):
def _build_graph(self, _, __): def _build_graph(self, _, __):
""" Do nothing. Graph was imported already """ """ Do nothing. Graph was imported already """
pass pass
...@@ -6,6 +6,7 @@ import tensorflow as tf ...@@ -6,6 +6,7 @@ import tensorflow as tf
from collections import namedtuple from collections import namedtuple
from six.moves import zip from six.moves import zip
from tensorpack.models import ModelDesc
from ..tfutils import * from ..tfutils import *
import multiprocessing import multiprocessing
...@@ -53,8 +54,10 @@ class PredictConfig(object): ...@@ -53,8 +54,10 @@ class PredictConfig(object):
# XXX does it work? start with minimal memory, but allow growth. # XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF. # allow_growth doesn't seem to work very well in TF.
self.session_config = kwargs.pop('session_config', get_default_sess_config(0.3)) self.session_config = kwargs.pop('session_config', get_default_sess_config(0.3))
self.session_init = kwargs.pop('session_init') self.session_init = kwargs.pop('session_init', JustCurrentSession())
assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model') self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
self.input_data_mapping = kwargs.pop('input_data_mapping', None) self.input_data_mapping = kwargs.pop('input_data_mapping', None)
self.output_var_names = kwargs.pop('output_var_names') self.output_var_names = kwargs.pop('output_var_names')
self.return_input = kwargs.pop('return_input', False) self.return_input = kwargs.pop('return_input', False)
...@@ -86,4 +89,5 @@ def get_predict_func(config): ...@@ -86,4 +89,5 @@ def get_predict_func(config):
def run_input(dp): def run_input(dp):
feed = dict(zip(input_map, dp)) feed = dict(zip(input_map, dp))
return sess.run(output_vars, feed_dict=feed) return sess.run(output_vars, feed_dict=feed)
run_input.session = sess
return run_input return run_input
...@@ -68,7 +68,9 @@ class SaverRestore(SessionInit): ...@@ -68,7 +68,9 @@ class SaverRestore(SessionInit):
chkpt_vars = SaverRestore._read_checkpoint_vars(self.path) chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars) vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars)
for dic in SaverRestore._produce_restore_dict(vars_map): for dic in SaverRestore._produce_restore_dict(vars_map):
saver = tf.train.Saver(var_list=dic) # multiple saver under same name scope would cause error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
saver.restore(sess, self.path) saver.restore(sess, self.path)
def set_path(self, model_path): def set_path(self, model_path):
...@@ -148,7 +150,10 @@ class ParamRestore(SessionInit): ...@@ -148,7 +150,10 @@ class ParamRestore(SessionInit):
"{}: {}!={}".format(name, varshape, value.shape) "{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during loading!".format(name)) logger.warn("Param {} is reshaped during loading!".format(name))
value = value.reshape(varshape) value = value.reshape(varshape)
sess.run(var.assign(value)) # assign(value) creates ops with values being saved, doubling the size of metagraph
# assign(placeholder) works better here
p = tf.placeholder(value.dtype, shape=value.shape)
sess.run(var.assign(p), feed_dict={p:value})
def ChainInit(SessionInit): def ChainInit(SessionInit):
""" Init a session by a list of SessionInit instance.""" """ Init a session by a list of SessionInit instance."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment