Commit e2cc7058 authored by Yuxin Wu

remove deprecated impl in BN. catch more exceptions in saver

parent 41122718
......@@ -73,8 +73,9 @@ class ModelSaver(Callback):
                 global_step=tf.train.get_global_step(),
                 write_meta_graph=False)
             logger.info("Model saved to %s." % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
-        except (OSError, IOError):  # disk error sometimes.. just ignore it
-            logger.exception("Exception in ModelSaver.trigger_epoch!")
+        except (OSError, IOError, tf.errors.PermissionDeniedError,
+                tf.errors.ResourceExhaustedError):  # disk error sometimes.. just ignore it
+            logger.exception("Exception in ModelSaver!")
 
 
 class MinSaver(Callback):
......
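The change above broadens ModelSaver's exception filter: when the disk is full or unwritable, the failure often surfaces from inside the save op as tf.errors.ResourceExhaustedError or tf.errors.PermissionDeniedError rather than as a plain OSError/IOError, so it previously escaped the handler and killed training. A minimal standalone sketch of the same best-effort pattern (the save_checkpoint wrapper and logger setup here are illustrative, not tensorpack code):

import logging
import tensorflow as tf

logger = logging.getLogger(__name__)

def save_checkpoint(saver, sess, path, step):
    """Best-effort save: log disk/permission errors and keep training."""
    try:
        saver.save(sess, path, global_step=step, write_meta_graph=False)
    except (OSError, IOError,
            tf.errors.PermissionDeniedError,
            tf.errors.ResourceExhaustedError):
        # A failed checkpoint is not worth crashing the training loop;
        # record the traceback and move on.
        logger.exception("Exception in save_checkpoint!")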
......@@ -17,85 +17,6 @@ __all__ = ['BatchNorm', 'BatchRenorm']
 # eps: torch: 1e-5. Lasagne: 1e-4
 
 
-# XXX This is deprecated. Only kept for future reference.
-@layer_register(log_shape=False)
-def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
-    shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
-
-    n_out = shape[-1]  # channel
-    assert n_out is not None
-    beta = tf.get_variable('beta', [n_out],
-                           initializer=tf.constant_initializer())
-    gamma = tf.get_variable('gamma', [n_out],
-                            initializer=tf.constant_initializer(1.0))
-
-    if len(shape) == 2:
-        batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
-    else:
-        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)
-    # just to make a clear name.
-    batch_mean = tf.identity(batch_mean, 'mean')
-    batch_var = tf.identity(batch_var, 'variance')
-
-    emaname = 'EMA'
-    ctx = get_current_tower_context()
-    if use_local_stat is None:
-        use_local_stat = ctx.is_training
-    if use_local_stat != ctx.is_training:
-        logger.warn("[BatchNorm] use_local_stat != is_training")
-
-    if use_local_stat:
-        # training tower
-        if ctx.is_training:
-            # reuse = tf.get_variable_scope().reuse
-            with tf.variable_scope(tf.get_variable_scope(), reuse=False):
-                # BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
-                with tf.name_scope(None):  # https://github.com/tensorflow/tensorflow/issues/2740
-                    # if reuse=True, try to find and use the existing statistics
-                    # how to use multiple tensors to update one EMA? seems impossible
-                    ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-                    ema_apply_op = ema.apply([batch_mean, batch_var])
-                    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-                    if ctx.is_main_training_tower:
-                        # inside main training tower
-                        add_model_variable(ema_mean)
-                        add_model_variable(ema_var)
-    else:
-        # no apply() is called here, no magic vars will get created,
-        # no reuse issue will happen
-        assert not ctx.is_training
-        with tf.name_scope(None):
-            ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-            mean_var_name = ema.average_name(batch_mean)
-            var_var_name = ema.average_name(batch_var)
-
-            if ctx.is_main_tower:
-                # main tower, but needs to use global stat. global stat must be from outside
-                # when reuse=True, the desired variable name could
-                # actually be different, because a different var is created
-                # for different reuse tower
-                ema_mean = tf.get_variable('mean/' + emaname, [n_out])
-                ema_var = tf.get_variable('variance/' + emaname, [n_out])
-            else:
-                # use statistics in another tower
-                G = tf.get_default_graph()
-                ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name + ':0')
-                ema_var = ctx.find_tensor_in_main_tower(G, var_var_name + ':0')
-
-    if use_local_stat:
-        batch = tf.cast(tf.shape(x)[0], tf.float32)
-        mul = tf.where(tf.equal(batch, 1.0), 1.0, batch / (batch - 1))
-        batch_var = batch_var * mul  # use unbiased variance estimator in training
-
-        with tf.control_dependencies([ema_apply_op] if ctx.is_training else []):
-            # only apply EMA op if is_training
-            return tf.nn.batch_normalization(
-                x, batch_mean, batch_var, beta, gamma, epsilon, 'output')
-    else:
-        return tf.nn.batch_normalization(
-            x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
-
-
 def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
     if use_bias:
         beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
......@@ -310,7 +231,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
         xn = tf.nn.batch_normalization(
             x, moving_mean, moving_var, beta, gamma, epsilon)
 
-    if ctx.is_main_training_tower:
+    if ctx.is_main_training_tower or ctx.has_own_variables:
         ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
         ret = tf.identity(xn, name='output')
......
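The one-line change at the bottom matters for multi-GPU training: with trainers where each tower keeps its own copy of the variables (ctx.has_own_variables, e.g. replicated multi-GPU training), every such tower must update its own moving statistics, not only the main training tower, or those copies would stay at their initial values. Conceptually, update_bn_ema attaches EMA updates of the moving statistics to the layer output, roughly like this simplified stand-in (not tensorpack's exact implementation):

import tensorflow as tf
from tensorflow.python.training import moving_averages

def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
    # moving_stat := decay * moving_stat + (1 - decay) * batch_stat
    update_op1 = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False, name='mean_ema_op')
    update_op2 = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False, name='var_ema_op')
    # Make the updates a dependency of the output, so they run
    # whenever the layer's output is computed during training.
    with tf.control_dependencies([update_op1, update_op2]):
        return tf.identity(xn, name='output')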