Commit 236c78e0 authored by Yuxin Wu

better save/restore with towerp & batch_norm

parent 178f3611
 ## DisturbLabel
-I ran into the paper [DisturbLabel: Regularizing CNN on the Loss Layer](https://arxiv.org/abs/1605.00055) on CVPR16.
+I ran into the paper [DisturbLabel: Regularizing CNN on the Loss Layer](https://arxiv.org/abs/1605.00055) on CVPR16,
+which basically said that noisy labels give you better performance.
 Like many, I didn't believe the method or the results at first.
 This is a simple mnist training script with DisturbLabel. It uses the architecture in the paper and
......
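For context, DisturbLabel regularizes training by randomly replacing a fraction of the labels in each batch with labels drawn uniformly from all classes. A minimal sketch of that operation (not the script above; the function name and the `alpha` parameter are illustrative):

```python
import numpy as np

def disturb_labels(labels, num_classes, alpha=0.1, rng=np.random):
    """Randomly replace a fraction `alpha` of integer labels with labels
    drawn uniformly from all classes (the original label may be redrawn)."""
    labels = np.asarray(labels).copy()
    mask = rng.uniform(size=labels.shape) < alpha
    labels[mask] = rng.randint(0, num_classes, size=int(mask.sum()))
    return labels
```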
@@ -43,6 +43,10 @@ def get_dorefa(bitW, bitA, bitG):
     def fw(x):
         if bitW == 32:
             return x
+        if bitW == 1:  # BWN
+            with G.gradient_override_map({"Sign": "Identity"}):
+                E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
+                return tf.sign(x / E) * E
         x = tf.tanh(x)
         x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
         return 2 * quantize(x, bitW) - 1
......
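The new `bitW == 1` branch implements binary weights (BWN): weights are binarized with `tf.sign` and rescaled by their mean absolute value, and `gradient_override_map({"Sign": "Identity"})` lets gradients pass straight through the non-differentiable sign op. The `quantize` helper used on the last line is defined earlier in the file; a rough sketch of how such a k-bit quantizer is commonly written with the same override trick (illustrative, not necessarily the file's exact code):

```python
import tensorflow as tf

def quantize_sketch(x, k):
    # Quantize x (assumed to lie in [0, 1]) to k bits. Overriding the Round
    # op's gradient with Identity is the straight-through estimator: the
    # forward pass rounds, the backward pass behaves like identity.
    n = float(2 ** k - 1)
    G = tf.get_default_graph()
    with G.gradient_override_map({"Round": "Identity"}):
        return tf.round(x * n) / n
```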
@@ -3,4 +3,4 @@ scipy
 nltk
 h5py
 pyzmq
-tornado
+tornado; python_version < '3.0'
@@ -64,20 +64,20 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
         assert not use_local_stat
         with tf.name_scope(None):
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-            ema_apply_op = ema.apply([batch_mean, batch_var])
-            ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+            mean_var_name = ema.average_name(batch_mean) + ':0'
+            var_var_name = ema.average_name(batch_var) + ':0'
         G = tf.get_default_graph()
         # find training statistics in training tower
         try:
-            mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
-            var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
+            mean_name = re.sub('towerp[0-9]+/', '', mean_var_name)
+            var_name = re.sub('towerp[0-9]+/', '', var_var_name)
             #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
             ema_mean = G.get_tensor_by_name(mean_name)
             ema_var = G.get_tensor_by_name(var_name)
         except KeyError:
-            mean_name = re.sub('towerp[0-9]+/', 'tower0/', ema_mean.name)
-            var_name = re.sub('towerp[0-9]+/', 'tower0/', ema_var.name)
+            mean_name = re.sub('towerp[0-9]+/', 'tower0/', mean_var_name)
+            var_name = re.sub('towerp[0-9]+/', 'tower0/', var_var_name)
             ema_mean = G.get_tensor_by_name(mean_name)
             ema_var = G.get_tensor_by_name(var_name)
         #logger.info("In prediction, using {} instead of {} for {}".format(
......
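The change above avoids creating a second set of EMA variables and update ops in the prediction tower: `ExponentialMovingAverage.average_name()` only computes the name the shadow variable would have (unlike `ema.apply()`, it creates nothing), and appending `':0'` turns that variable name into a tensor name that `get_tensor_by_name` can look up once the `towerp<i>/` prefix is stripped. A small sketch of the naming, assuming an averager named `'EMA'`:

```python
import tensorflow as tf

x = tf.Variable(0.0, name='bn/mean')
ema = tf.train.ExponentialMovingAverage(decay=0.9, name='EMA')

# average_name() derives the shadow variable's name without creating it,
# so it is safe to call in a prediction tower that must not add ops.
print(ema.average_name(x))          # typically 'bn/mean/EMA'
print(ema.average_name(x) + ':0')   # the corresponding tensor name
```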
@@ -94,7 +94,11 @@ class SaverRestore(SessionInit):
     @staticmethod
     def _read_checkpoint_vars(model_path):
         reader = tf.train.NewCheckpointReader(model_path)
-        return set(reader.get_variable_to_shape_map().keys())
+        ckpt_vars = reader.get_variable_to_shape_map().keys()
+        for v in ckpt_vars:
+            if v.startswith('towerp'):
+                logger.warn("Found {} in checkpoint. Anything from prediction tower shouldn't be saved.".format(v))
+        return set(ckpt_vars)

     @staticmethod
     def _get_vars_to_restore_multimap(vars_available):
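The per-variable warning builds on `tf.train.NewCheckpointReader`, which lists checkpoint variables without building a graph. A minimal usage sketch (the checkpoint path is a placeholder):

```python
import tensorflow as tf

# Placeholder path; point this at a real checkpoint prefix.
reader = tf.train.NewCheckpointReader('/path/to/model-10000')
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)
```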
@@ -102,13 +106,14 @@ class SaverRestore(SessionInit):
         Get a dict of {var_name: [var, var]} to restore
         :param vars_available: variables available in the checkpoint, for existence checking
         """
-        # TODO warn if some variable in checkpoint is not used
         vars_to_restore = tf.all_variables()
         var_dict = defaultdict(list)
         for v in vars_to_restore:
             name = v.op.name
             if 'towerp' in name:
-                logger.warn("Anything from prediction tower shouldn't be saved.")
+                logger.warn("Variable {} in prediction tower shouldn't exist.".format(v.name))
+                # don't overwrite anything in the current prediction graph
+                continue
             if 'tower' in name:
                 new_name = re.sub('tower[p0-9]+/', '', name)
                 name = new_name
@@ -117,6 +122,7 @@ class SaverRestore(SessionInit):
                 vars_available.remove(name)
             else:
                 logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
+        # TODO warn if some variable in checkpoint is not used
         #for name in vars_available:
             #logger.warn("Param {} in checkpoint doesn't appear in the graph!".format(name))
         return var_dict
......
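Putting the restore logic together: training-tower variable names have their `tower<i>/` prefix stripped so several graph variables can map onto a single checkpoint entry, while anything under a `towerp` (prediction tower) prefix is skipped entirely. A self-contained sketch of that mapping with made-up variable names:

```python
import re
from collections import defaultdict

def build_restore_multimap(graph_var_names, ckpt_var_names):
    """Map each checkpoint name to the graph variables it should restore."""
    var_dict = defaultdict(list)
    for name in graph_var_names:
        if 'towerp' in name:
            continue  # never restore into the prediction tower
        ckpt_name = re.sub('tower[p0-9]+/', '', name)
        if ckpt_name in ckpt_var_names:
            var_dict[ckpt_name].append(name)
    return dict(var_dict)

print(build_restore_multimap(
    ['conv0/W', 'tower0/conv0/W', 'tower1/conv0/W', 'towerp0/conv0/W'],
    {'conv0/W'}))
# -> {'conv0/W': ['conv0/W', 'tower0/conv0/W', 'tower1/conv0/W']}
```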