Commit af2c0e9c authored by Yuxin Wu

session update

parent 3a431489
@@ -11,6 +11,8 @@ import tensorflow as tf
import six
from ..utils import logger, EXTRA_SAVE_VARS_KEY
from .common import get_op_var_name
from .sessupdate import SessionUpdate

__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
           'ParamRestore', 'ChainInit',
@@ -142,35 +144,26 @@ class ParamRestore(SessionInit):
"""
:param param_dict: a dict of {name: value}
"""
self.prms = param_dict
self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
    def _init(self, sess):
        # allow restoring non-trainable variables
        variables = tf.get_collection(tf.GraphKeys.VARIABLES)
        var_dict = dict([v.name, v] for v in variables)
        for name, value in six.iteritems(self.prms):
            if not name.endswith(':0'):
                name = name + ':0'
            try:
                var = var_dict[name]
            except (ValueError, KeyError):
                logger.warn("Param {} not found in this graph".format(name))
                continue
            del var_dict[name]
            logger.info("Restoring param {}".format(name))
            varshape = tuple(var.get_shape().as_list())
            if varshape != value.shape:
                # TODO only allow reshape when set(shape) is the same or differs by 1
                assert np.prod(varshape) == np.prod(value.shape), \
                    "{}: {} != {}".format(name, varshape, value.shape)
                logger.warn("Param {} is reshaped during loading!".format(name))
                value = value.reshape(varshape)
            # assign(value) creates ops with the values baked in, doubling the size of the metagraph;
            # assign(placeholder) works better here
            p = tf.placeholder(value.dtype, shape=value.shape)
            sess.run(var.assign(p), feed_dict={p: value})
        if var_dict:
            logger.warn("Some variables in the graph are not restored: {}".format(str(var_dict)))
        variable_names = set([k.name for k in variables])
        param_names = set(six.iterkeys(self.prms))
        intersect = variable_names & param_names
        logger.info("Params to restore: {}".format(
            ', '.join(map(str, intersect))))
        for k in variable_names - param_names:
            logger.warn("Variable {} in the graph won't be restored!".format(k))
        for k in param_names - variable_names:
            logger.warn("Param {} not found in this graph!".format(k))

        upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
        logger.info("Restoring from param dict ...")
        upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
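
A minimal usage sketch of the updated ParamRestore (not part of this diff; the npz file name and the config attribute are illustrative assumptions): parameter names may be given with or without the ':0' suffix, and only names that match a graph variable are assigned, via SessionUpdate.

import numpy as np
params = dict(np.load('model.npz'))          # {variable name: ndarray}, e.g. 'conv0/W'
config.session_init = ParamRestore(params)   # assumed TrainConfig-style attribute
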
class ChainInit(SessionInit):
    """ Initialize a session with a list of SessionInit instances. """
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: sessupdate.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import six
import numpy as np
import tensorflow as tf
from ..utils import logger

__all__ = ['SessionUpdate']


class SessionUpdate(object):
    """ Update the variables in a session. """
    def __init__(self, sess, vars_to_update):
        """
        :param vars_to_update: a collection of variables to update
        """
        self.sess = sess
        self.assign_ops = {}
        for v in vars_to_update:
            # one placeholder + assign op per variable, created once and reused,
            # so repeated updates don't keep adding ops to the graph
            p = tf.placeholder(v.dtype, shape=v.get_shape())
            self.assign_ops[v.name] = (p, v.assign(p))

    def update(self, prms):
        """
        :param prms: dict of {variable name: value}
            Any name in prms must be in the graph and in vars_to_update.
        """
        for name, value in six.iteritems(prms):
            p, op = self.assign_ops[name]
            varshape = tuple(p.get_shape().as_list())
            if varshape != value.shape:
                # TODO only allow reshape when the shapes differ by an empty axis
                assert np.prod(varshape) == np.prod(value.shape), \
                    "{}: {} != {}".format(name, varshape, value.shape)
                logger.warn("Param {} is reshaped during assigning".format(name))
                value = value.reshape(varshape)
            self.sess.run(op, feed_dict={p: value})
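
A small self-contained sketch of how SessionUpdate might be used on its own (variable name and shapes are illustrative; TF 0.x-era API assumed, matching the rest of this commit):

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    W = tf.Variable(np.zeros((3, 4), dtype='float32'), name='fc0/W')
    sess.run(tf.initialize_all_variables())
    upd = SessionUpdate(sess, [W])            # placeholder/assign ops are built once here
    upd.update({'fc0/W:0': np.ones((3, 4), dtype='float32')})       # keys include the ':0' suffix
    upd.update({'fc0/W:0': np.full((3, 4), 2.0, dtype='float32')})  # reuse, no new graph ops
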
@@ -104,6 +104,7 @@ def summary_moving_average():
    vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
    avg_maintain_op = averager.apply(vars_to_summary)
    for idx, c in enumerate(vars_to_summary):
        # TODO assert scalar
        name = re.sub('tower[p0-9]+/', '', c.op.name)
        tf.scalar_summary(name, averager.average(c))
    return avg_maintain_op
......
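
For context, a sketch of how a tensor presumably reaches MOVING_SUMMARY_VARS_KEY in the first place (an assumption inferred from this hunk and from the add_moving_summary call in the trainer hunk below, not code from this commit):

# register a scalar tensor for moving-average summarization
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
# summary_moving_average() then builds the EMA-maintaining op and one
# scalar summary per collected tensor, with any 'towerN/' name prefix stripped
avg_maintain_op = summary_moving_average()
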
@@ -62,7 +62,7 @@ class SimpleTrainer(Trainer):
        model = self.model
        self.input_vars = model.get_input_vars()
        model.build_graph(self.input_vars, True)
        cost_var = model.get_cost()
        cost_var = model.get_cost()  # TODO assert scalar
        add_moving_summary(cost_var)
        grads = self.config.optimizer.compute_gradients(cost_var)
......
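
Both "TODO assert scalar" notes could plausibly be resolved with a small rank check such as the following (a sketch under that assumption, not part of this commit):

def assert_scalar(t):
    # a tensor qualifies as a scalar when its static shape has rank 0
    assert t.get_shape().ndims == 0, \
        "{} should be a scalar, but has shape {}".format(t.name, t.get_shape())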