Commit 78595e71 authored by Yuxin Wu

significantly improves speed of DictRestore

Loading a large COCO model drops from 50 sec to 0.4 sec.
parent 5a868442
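The speedup comes from batching session calls: the old code ran one `var.load()` (i.e. one `sess.run`) per variable, while the new code feeds every variable's initializer in a single `sess.run` (see the `SessionUpdate.update` hunk below). A minimal sketch of the idea — not tensorpack code, assuming TF1-style graph mode as tensorpack uses:

    import numpy as np
    import tensorflow as tf

    variables = [tf.get_variable('v{}'.format(i), shape=[4]) for i in range(3)]
    values = {v: np.ones([4], dtype=np.float32) for v in variables}

    with tf.Session() as sess:
        # Slow path: each var.load() is its own session round-trip,
        # so restoring N variables costs N session calls.
        for v, val in values.items():
            v.load(val, session=sess)

        # Fast path (what this commit does): v.initializer is an Assign op
        # whose second input is the initial-value tensor; feeding it and
        # running all initializers together costs one session call total.
        fetches = [v.initializer for v in variables]
        feeds = {v.initializer.inputs[1]: val for v, val in values.items()}
        sess.run(fetches, feed_dict=feeds)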
@@ -15,6 +15,9 @@ from config import config as cfg

 @layer_register(log_shape=True)
 def GroupNorm(x, group=32, gamma_initializer=tf.constant_initializer(1.)):
+    """
+    More code that reproduces the paper can be found at https://github.com/ppwwyyxx/GroupNorm-reproduce/.
+    """
     shape = x.get_shape().as_list()
     ndims = len(shape)
     assert ndims == 4, shape
...
@@ -17,6 +17,7 @@ from imagenet_utils import ImageNetModel, fbresnet_augmentor, get_imagenet_dataflow
 def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
     """
     https://arxiv.org/abs/1803.08494
+    More code that reproduces the paper can be found at https://github.com/ppwwyyxx/GroupNorm-reproduce/.
     """
     shape = x.get_shape().as_list()
     ndims = len(shape)
...
@@ -13,9 +13,13 @@ __all__ = ['CallbackToHook', 'HookToCallback']

 class CallbackToHook(tfv1.train.SessionRunHook):
-    """ This is only for internal implementation of
-    before_run/after_run callbacks.
-    You shouldn't need to use this.
+    """
+    Hooks are less powerful than callbacks so the conversion is incomplete.
+    It only converts the `before_run/after_run` calls.
+
+    This is only for internal implementation of
+    before_run/after_run callbacks.
+    You shouldn't need to use this.
     """

     def __init__(self, cb):
...
@@ -174,7 +174,8 @@ class SaverRestoreRelaxed(SaverRestore):

         def f(reader, name, v):
             val = reader.get_tensor(name)
-            SessionUpdate.load_value_to_var(v, val)
+            v.load(SessionUpdate.relaxed_value_for_var(val, v))
+
         with sess.as_default():
             self._match_vars(f)
...
@@ -47,30 +47,32 @@ class SessionUpdate(object):
         self.name_map = {v.name: v for v in vars_to_update}

     @staticmethod
-    def load_value_to_var(var, val, strict=False):
+    def relaxed_value_for_var(value, var):
         """
-        Call `var.load(val)` with the default session, with some type checks.
+        Returns a relaxed (possibly reshaped/upcasted) version of value,
+        to be loaded to the given variable.

         Args:
+            value (ndarray): a numpy array to be loaded to var
             var (tf.Variable):
-            strict (bool): Behave less strict if set to False.
+
+        Returns:
+            ndarray: a possibly reshaped or casted version of value
         """
-        if strict:
-            var.load(val)
-            return
+        assert isinstance(var, tf.Variable)
         name = var.op.name

         # check incompatible shape
         varshape = tuple(var.get_shape().as_list())
-        if varshape != val.shape:
+        if varshape != value.shape:
             # TODO only allow reshape when shape different by empty axis
-            if np.prod(varshape) != np.prod(val.shape):
+            if np.prod(varshape) != np.prod(value.shape):
                 raise ValueError(
                     "Trying to load a tensor of shape {} into the variable '{}' whose shape is {}.".format(
-                        val.shape, name, varshape))
+                        value.shape, name, varshape))
             logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format(
-                val.shape, varshape, name))
-            val = val.reshape(varshape)
+                value.shape, varshape, name))
+            value = value.reshape(varshape)

         # fix some common type incompatibility problems, but not all
         def upcast(vartype, valtype):
@@ -81,20 +83,17 @@ class SessionUpdate(object):
                 return np.int64 if vartype == tf.int64 else np.int32
             return None

-        if hasattr(val, 'dtype'):
+        if hasattr(value, 'dtype'):
             vartype = var.value().dtype
-            if vartype != val.dtype:
-                msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, val.dtype)
-                newtype = upcast(var.dtype.base_dtype, val.dtype)
+            if vartype != value.dtype:
+                msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, value.dtype)
+                newtype = upcast(var.dtype.base_dtype, value.dtype)
                 if newtype is not None:
-                    val = newtype(val)
+                    value = newtype(value)
                     logger.warn(msg + " Load it after casting!")
                 else:
-                    assert vartype == val.dtype, msg
-        try:
-            var.load(val)
-        except tf.errors.InvalidArgumentError:
-            logger.exc("Cannot load this value to the variable {}".format(name))
+                    assert vartype == value.dtype, msg
+        return value

     def update(self, prms):
         """
@@ -103,10 +102,15 @@ class SessionUpdate(object):
         Any name in prms must be in the graph and in vars_to_update.
         """
         with self.sess.as_default():
+            fetches = []
+            feeds = {}
             for name, value in six.iteritems(prms):
                 assert name in self.name_map
-                v = self.name_map[name]
-                SessionUpdate.load_value_to_var(v, value)
+                var = self.name_map[name]
+                fetches.append(var.initializer)
+                # This is the implementation of `var.load`
+                feeds[var.initializer.inputs[1]] = SessionUpdate.relaxed_value_for_var(value, var)
+            self.sess.run(fetches, feed_dict=feeds)

 def dump_session_params(path):
...
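Note that `relaxed_value_for_var` no longer loads anything itself: it only fixes up the value (reshaping across empty axes, upcasting compatible dtypes) and returns it, leaving the caller free to load it however it wants. A hypothetical usage sketch — the import path `tensorpack.tfutils.varmanip` is assumed here, not stated in the diff:

    import numpy as np
    import tensorflow as tf
    from tensorpack.tfutils.varmanip import SessionUpdate

    var = tf.get_variable('w', shape=[4], dtype=tf.float32)
    # Same number of elements but an extra axis: reshaped with a warning.
    val = np.zeros([4, 1], dtype=np.float32)
    fixed = SessionUpdate.relaxed_value_for_var(val, var)
    assert fixed.shape == (4,)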
@@ -40,11 +40,13 @@ def timed_operation(msg, log_start=False):
         Good stuff finished, time:1sec.
     """
+    assert len(msg)
     if log_start:
         logger.info('Start {} ...'.format(msg))
     start = timer()
     yield
-    logger.info('{} finished, time:{:.4f}sec.'.format(
+    msg = msg[0].upper() + msg[1:]
+    logger.info('{} finished, time:{:.4f} sec.'.format(
         msg, timer() - start))
...
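A short usage example for the context manager changed above (the import path `tensorpack.utils.utils` is assumed; `restore_model()` is a hypothetical workload):

    from tensorpack.utils.utils import timed_operation

    with timed_operation('loading weights', log_start=True):
        restore_model()   # hypothetical workload
    # After this commit the message is capitalized and gains a space:
    # "Loading weights finished, time:0.4000 sec."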