Commit 236c78e0 authored by Yuxin Wu

better save/restore with towerp & batch_norm

parent 178f3611
## DisturbLabel
I ran into the paper [DisturbLabel: Regularizing CNN on the Loss Layer](https://arxiv.org/abs/1605.00055) on CVPR16,
which basically said that noisy data gives you better performance.
Like many others, I didn't believe the method or the results at first.
This is a simple mnist training script with DisturbLabel. It uses the architecture in the paper and
......
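For reference, the core of DisturbLabel is tiny: with probability alpha, each ground-truth label is replaced by a label drawn uniformly over all classes. A minimal numpy sketch of the idea (function name, signature, and the alpha value here are illustrative, not taken from this script):

```python
import numpy as np

def disturb_labels(labels, num_classes=10, alpha=0.1, rng=np.random):
    """With probability alpha, replace each label with a uniformly random
    class (which may happen to be the correct one), as in the paper."""
    labels = labels.copy()
    mask = rng.uniform(size=labels.shape) < alpha
    labels[mask] = rng.randint(0, num_classes, size=int(mask.sum()))
    return labels
```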
......
@@ -43,6 +43,10 @@ def get_dorefa(bitW, bitA, bitG):
    def fw(x):
        if bitW == 32:
            return x
+       if bitW == 1:   # BWN (Binary Weight Networks): sign with per-layer scale E
+           with G.gradient_override_map({"Sign": "Identity"}):
+               E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
+               return tf.sign(x / E) * E
        x = tf.tanh(x)
        x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5   # map to [0, 1]
        return 2 * quantize(x, bitW) - 1               # k-bit weights in [-1, 1]
......
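For context, `fw` above relies on a `quantize` helper defined earlier in this file. A sketch of the k-bit uniform quantizer described in the DoReFa-Net paper, assuming the input is already in [0, 1] (this illustrates the idea and is not necessarily the exact code in the repo):

```python
import tensorflow as tf

def quantize(x, k):
    # Map x in [0, 1] onto 2^k - 1 uniform levels; override the gradient
    # of the rounding op with identity ("straight-through estimator") so
    # backprop can pass through the non-differentiable quantization.
    n = float(2 ** k - 1)
    G = tf.get_default_graph()
    with G.gradient_override_map({"Round": "Identity"}):
        return tf.round(x * n) / n
```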
......
@@ -3,4 +3,4 @@ scipy
nltk
h5py
pyzmq
-tornado
+tornado; python_version < '3.0'
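(The `; python_version < '3.0'` suffix is a standard pip environment marker, so tornado is only installed under Python 2.)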
......
@@ -64,20 +64,20 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
        assert not use_local_stat
        with tf.name_scope(None):
            ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
            ema_apply_op = ema.apply([batch_mean, batch_var])
-           ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+           mean_var_name = ema.average_name(batch_mean) + ':0'
+           var_var_name = ema.average_name(batch_var) + ':0'
        G = tf.get_default_graph()
        # find training statistics in training tower
        try:
-           mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
-           var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
+           mean_name = re.sub('towerp[0-9]+/', '', mean_var_name)
+           var_name = re.sub('towerp[0-9]+/', '', var_var_name)
            #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
            ema_mean = G.get_tensor_by_name(mean_name)
            ema_var = G.get_tensor_by_name(var_name)
        except KeyError:
-           mean_name = re.sub('towerp[0-9]+/', 'tower0/', ema_mean.name)
-           var_name = re.sub('towerp[0-9]+/', 'tower0/', ema_var.name)
+           mean_name = re.sub('towerp[0-9]+/', 'tower0/', mean_var_name)
+           var_name = re.sub('towerp[0-9]+/', 'tower0/', var_var_name)
            ema_mean = G.get_tensor_by_name(mean_name)
            ema_var = G.get_tensor_by_name(var_name)
            #logger.info("In prediction, using {} instead of {} for {}".format(
......
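To see what the `re.sub` calls above do, consider a hypothetical EMA tensor name as it would appear inside a prediction tower (the names below are made up for illustration):

```python
import re

pred_name = 'towerp0/conv1/bn/mean/EMA:0'

# First attempt: the statistics were created outside any tower scope
# (single-tower training), so just strip the 'towerpN/' prefix.
name_at_root = re.sub('towerp[0-9]+/', '', pred_name)
print(name_at_root)    # conv1/bn/mean/EMA:0

# Fallback (the KeyError branch): the statistics were created inside
# the first training tower, so remap 'towerpN/' to 'tower0/'.
name_in_tower0 = re.sub('towerp[0-9]+/', 'tower0/', pred_name)
print(name_in_tower0)  # tower0/conv1/bn/mean/EMA:0
```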
......
@@ -94,7 +94,11 @@ class SaverRestore(SessionInit):
    @staticmethod
    def _read_checkpoint_vars(model_path):
        reader = tf.train.NewCheckpointReader(model_path)
-       return set(reader.get_variable_to_shape_map().keys())
+       ckpt_vars = reader.get_variable_to_shape_map().keys()
+       for v in ckpt_vars:
+           if v.startswith('towerp'):
+               logger.warn("Found {} in checkpoint. Anything from prediction tower shouldn't be saved.".format(v))
+       return set(ckpt_vars)
    @staticmethod
    def _get_vars_to_restore_multimap(vars_available):
......
@@ -102,13 +106,14 @@ class SaverRestore(SessionInit):
        Get a dict of {var_name: [var, var]} to restore
        :param vars_available: variables available in the checkpoint, for existence checking
        """
-       # TODO warn if some variable in checkpoint is not used
        vars_to_restore = tf.all_variables()
        var_dict = defaultdict(list)
        for v in vars_to_restore:
            name = v.op.name
            if 'towerp' in name:
-               logger.warn("Anything from prediction tower shouldn't be saved.")
+               logger.warn("Variable {} in prediction tower shouldn't exist.".format(v.name))
+               # don't overwrite anything in the current prediction graph
+               continue
            if 'tower' in name:
                new_name = re.sub('tower[p0-9]+/', '', name)
                name = new_name
......
@@ -117,6 +122,7 @@ class SaverRestore(SessionInit):
                vars_available.remove(name)
            else:
                logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
+       # TODO warn if some variable in checkpoint is not used
        #for name in vars_available:
        #    logger.warn("Param {} in checkpoint doesn't appear in the graph!".format(name))
        return var_dict
......
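The effect of the name mapping above is that every training-tower copy of a variable collapses onto the single entry stored in the checkpoint, while prediction-tower ("towerp") variables are skipped entirely. A small self-contained illustration of that mapping (variable names and checkpoint contents are hypothetical):

```python
import re
from collections import defaultdict

graph_vars = ['tower0/conv1/W', 'towerp0/conv1/W', 'global_step']
ckpt_vars = {'conv1/W', 'global_step'}   # names stored in the checkpoint

var_dict = defaultdict(list)
for name in graph_vars:
    if 'towerp' in name:
        continue                         # prediction tower: never saved/restored
    name = re.sub('tower[p0-9]+/', '', name)
    if name in ckpt_vars:
        var_dict[name].append(name)      # the real code appends the Variable itself
    else:
        print("Param %s not found in checkpoint!" % name)

print(dict(var_dict))   # {'conv1/W': ['conv1/W'], 'global_step': ['global_step']}
```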