Commit 78595e71 authored by Yuxin Wu's avatar Yuxin Wu

significantly improves speed of DictRestore

Loading a large COCO model takes 50 sec -> 0.4 sec
parent 5a868442
......@@ -15,6 +15,9 @@ from config import config as cfg
@layer_register(log_shape=True)
def GroupNorm(x, group=32, gamma_initializer=tf.constant_initializer(1.)):
"""
More code that reproduces the paper can be found at https://github.com/ppwwyyxx/GroupNorm-reproduce/.
"""
shape = x.get_shape().as_list()
ndims = len(shape)
assert ndims == 4, shape
......
......@@ -17,6 +17,7 @@ from imagenet_utils import ImageNetModel, fbresnet_augmentor, get_imagenet_dataf
def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
"""
https://arxiv.org/abs/1803.08494
More code that reproduces the paper can be found at https://github.com/ppwwyyxx/GroupNorm-reproduce/.
"""
shape = x.get_shape().as_list()
ndims = len(shape)
......
......@@ -13,7 +13,11 @@ __all__ = ['CallbackToHook', 'HookToCallback']
class CallbackToHook(tfv1.train.SessionRunHook):
""" This is only for internal implementation of
"""
Hooks are less powerful than callbacks so the conversion is incomplete.
It only converts the `before_run/after_run` calls.
This is only for internal implementation of
before_run/after_run callbacks.
You shouldn't need to use this.
"""
......
......@@ -174,7 +174,8 @@ class SaverRestoreRelaxed(SaverRestore):
def f(reader, name, v):
val = reader.get_tensor(name)
SessionUpdate.load_value_to_var(v, val)
v.load(SessionUpdate.relaxed_value_for_var(val, v))
with sess.as_default():
self._match_vars(f)
......
......@@ -47,30 +47,32 @@ class SessionUpdate(object):
self.name_map = {v.name: v for v in vars_to_update}
@staticmethod
def load_value_to_var(var, val, strict=False):
def relaxed_value_for_var(value, var):
"""
Call `var.load(val)` with the default session, with some type checks.
Returns a relaxed (possibly reshaped/upcast-ed) version of value,
to be loaded to the given variable.
Args:
value (ndarray): an numpy array to be loaded to var
var (tf.Variable):
strict (bool): Behave less strict if set to False.
Returns:
ndarray: a possibly reshaped or casted version of value
"""
if strict:
var.load(val)
return
assert isinstance(var, tf.Variable)
name = var.op.name
# check incompatible shape
varshape = tuple(var.get_shape().as_list())
if varshape != val.shape:
if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis
if np.prod(varshape) != np.prod(val.shape):
if np.prod(varshape) != np.prod(value.shape):
raise ValueError(
"Trying to load a tensor of shape {} into the variable '{}' whose shape is {}.".format(
val.shape, name, varshape))
value.shape, name, varshape))
logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format(
val.shape, varshape, name))
val = val.reshape(varshape)
value.shape, varshape, name))
value = value.reshape(varshape)
# fix some common type incompatibility problems, but not all
def upcast(vartype, valtype):
......@@ -81,20 +83,17 @@ class SessionUpdate(object):
return np.int64 if vartype == tf.int64 else np.int32
return None
if hasattr(val, 'dtype'):
if hasattr(value, 'dtype'):
vartype = var.value().dtype
if vartype != val.dtype:
msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, val.dtype)
newtype = upcast(var.dtype.base_dtype, val.dtype)
if vartype != value.dtype:
msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, value.dtype)
newtype = upcast(var.dtype.base_dtype, value.dtype)
if newtype is not None:
val = newtype(val)
value = newtype(value)
logger.warn(msg + " Load it after casting!")
else:
assert vartype == val.dtype, msg
try:
var.load(val)
except tf.errors.InvalidArgumentError:
logger.exc("Cannot load this value to the variable {}".format(name))
assert vartype == value.dtype, msg
return value
def update(self, prms):
"""
......@@ -103,10 +102,15 @@ class SessionUpdate(object):
Any name in prms must be in the graph and in vars_to_update.
"""
with self.sess.as_default():
fetches = []
feeds = {}
for name, value in six.iteritems(prms):
assert name in self.name_map
v = self.name_map[name]
SessionUpdate.load_value_to_var(v, value)
var = self.name_map[name]
fetches.append(var.initializer)
# This is the implementation of `var.load`
feeds[var.initializer.inputs[1]] = SessionUpdate.relaxed_value_for_var(value, var)
self.sess.run(fetches, feed_dict=feeds)
def dump_session_params(path):
......
......@@ -40,11 +40,13 @@ def timed_operation(msg, log_start=False):
Good stuff finished, time:1sec.
"""
assert len(msg)
if log_start:
logger.info('Start {} ...'.format(msg))
start = timer()
yield
logger.info('{} finished, time:{:.4f}sec.'.format(
msg = msg[0].upper() + msg[1:]
logger.info('{} finished, time:{:.4f} sec.'.format(
msg, timer() - start))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment