Commit af2c0e9c authored by Yuxin Wu

session update

parent 3a431489
@@ -11,6 +11,8 @@ import tensorflow as tf
 import six
 
 from ..utils import logger, EXTRA_SAVE_VARS_KEY
+from .common import get_op_var_name
+from .sessupdate import SessionUpdate
 
 __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
            'ParamRestore', 'ChainInit',
@@ -142,35 +144,26 @@ class ParamRestore(SessionInit):
         """
         :param param_dict: a dict of {name: value}
         """
-        self.prms = param_dict
+        # normalize keys to variable names (with the ':0' suffix)
+        self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
 
     def _init(self, sess):
+        # allow restoring non-trainable variables
         variables = tf.get_collection(tf.GraphKeys.VARIABLES)
-        var_dict = dict([v.name, v] for v in variables)
-        for name, value in six.iteritems(self.prms):
-            if not name.endswith(':0'):
-                name = name + ':0'
-            try:
-                var = var_dict[name]
-            except (ValueError, KeyError):
-                logger.warn("Param {} not found in this graph".format(name))
-                continue
-            del var_dict[name]
-            logger.info("Restoring param {}".format(name))
-            varshape = tuple(var.get_shape().as_list())
-            if varshape != value.shape:
-                # TODO only allow reshape when set(shape) is the same or different by 1
-                assert np.prod(varshape) == np.prod(value.shape), \
-                    "{}: {}!={}".format(name, varshape, value.shape)
-                logger.warn("Param {} is reshaped during loading!".format(name))
-                value = value.reshape(varshape)
-            # assign(value) creates ops with values being saved, doubling the size of metagraph
-            # assign(placeholder) works better here
-            p = tf.placeholder(value.dtype, shape=value.shape)
-            sess.run(var.assign(p), feed_dict={p: value})
-        if var_dict:
-            logger.warn("Some variables in the graph are not restored: {}".format(str(var_dict)))
+        variable_names = set([k.name for k in variables])
+        param_names = set(six.iterkeys(self.prms))
+        intersect = variable_names & param_names
+        logger.info("Params to restore: {}".format(
+            ', '.join(map(str, intersect))))
+        for k in variable_names - param_names:
+            logger.warn("Variable {} in the graph won't be restored!".format(k))
+        for k in param_names - variable_names:
+            logger.warn("Param {} not found in this graph!".format(k))
+        upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
+        logger.info("Restoring from param dict ...")
+        upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
 
 class ChainInit(SessionInit):
     """ Init a session by a list of SessionInit instance."""
...
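The comment deleted in this hunk records the rationale that the new SessionUpdate class (added in the file below) now carries: `var.assign(value)` with a concrete array creates a constant op holding the value inside the graph, roughly doubling the size of the saved metagraph, whereas assigning through a placeholder keeps the graph small. A minimal standalone sketch of the placeholder pattern, assuming the TF 0.x API this codebase targets (the variable name `demo/W`, its shape, and the session setup are illustrative only):

```python
import numpy as np
import tensorflow as tf

var = tf.get_variable('demo/W', shape=[3, 3])  # hypothetical variable
value = np.ones((3, 3), dtype='float32')

sess = tf.Session()
sess.run(tf.initialize_all_variables())  # TF 0.x initializer

# sess.run(var.assign(value)) would bake `value` into the graph as a constant;
# feeding it through a placeholder stores only the placeholder and assign ops.
p = tf.placeholder(value.dtype, shape=value.shape)
sess.run(var.assign(p), feed_dict={p: value})
```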
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: sessupdate.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import six
import numpy as np
import tensorflow as tf

from ..utils import logger

__all__ = ['SessionUpdate']

class SessionUpdate(object):
    """ Update the variables in a session """
    def __init__(self, sess, vars_to_update):
        """
        :param vars_to_update: a collection of variables to update
        """
        self.sess = sess
        self.assign_ops = {}
        for v in vars_to_update:
            # assign through a placeholder, so the values are fed at run time
            # instead of being stored in the graph
            p = tf.placeholder(v.dtype, shape=v.get_shape())
            self.assign_ops[v.name] = (p, v.assign(p))

    def update(self, prms):
        """
        :param prms: dict of {variable name: value}
        Any name in prms must be in the graph and in vars_to_update.
        """
        for name, value in six.iteritems(prms):
            p, op = self.assign_ops[name]
            varshape = tuple(p.get_shape().as_list())
            if varshape != value.shape:
                # TODO only allow reshape when shapes differ by empty axes
                assert np.prod(varshape) == np.prod(value.shape), \
                    "{}: {} != {}".format(name, varshape, value.shape)
                logger.warn("Param {} is reshaped during assigning".format(name))
                value = value.reshape(varshape)
            self.sess.run(op, feed_dict={p: value})
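A minimal usage sketch of SessionUpdate under the same TF 0.x assumptions (the variable name `fc/W` and its shape are hypothetical). Note that the keys passed to `update()` are variable names with the `':0'` suffix, which is exactly what ParamRestore now produces via `get_op_var_name`:

```python
import numpy as np
import tensorflow as tf

W = tf.get_variable('fc/W', shape=[10, 4])  # hypothetical variable

sess = tf.Session()
sess.run(tf.initialize_all_variables())

upd = SessionUpdate(sess, [W])
upd.update({'fc/W:0': np.zeros((10, 4), dtype='float32')})
```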
@@ -104,6 +104,7 @@ def summary_moving_average():
     vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     avg_maintain_op = averager.apply(vars_to_summary)
     for idx, c in enumerate(vars_to_summary):
+        # TODO assert scalar
         name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.scalar_summary(name, averager.average(c))
     return avg_maintain_op
...
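The `TODO assert scalar` added here could be resolved with a rank check before each summary is created; a sketch, assuming every tensor collected under MOVING_SUMMARY_VARS_KEY is meant to be rank-0 (`assert_scalar` is a hypothetical helper, not part of the codebase):

```python
def assert_scalar(t):
    # a moving-average scalar summary only makes sense for a rank-0 tensor
    assert t.get_shape().ndims == 0, \
        "{} is not a scalar (shape: {})".format(t.op.name, t.get_shape())
```

Calling `assert_scalar(c)` just before `tf.scalar_summary(...)` would cover this TODO, and `assert_scalar(cost_var)` would likewise cover the matching TODO in the SimpleTrainer hunk below.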
@@ -62,7 +62,7 @@ class SimpleTrainer(Trainer):
         model = self.model
         self.input_vars = model.get_input_vars()
         model.build_graph(self.input_vars, True)
-        cost_var = model.get_cost()
+        cost_var = model.get_cost()  # TODO assert scalar
         add_moving_summary(cost_var)
         grads = self.config.optimizer.compute_gradients(cost_var)
...